Codebase list matrix-synapse / 10a991c
New upstream version 0.27.2+dfsg Andrew Shadura 6 years ago
271 changed file(s) with 13161 addition(s) and 6629 deletion(s).
4545
4646 env/
4747 *.config
48
49 .vscode/
0 Changes in synapse v0.27.2 (2018-03-26)
1 =======================================
2
3 Bug fixes:
4
5 * Fix bug which broke TCP replication between workers (PR #3015)
6
7
8 Changes in synapse v0.27.1 (2018-03-26)
9 =======================================
10
11 Meta release as v0.27.0 temporarily pointed to the wrong commit
12
13
14 Changes in synapse v0.27.0 (2018-03-26)
15 =======================================
16
17 No changes since v0.27.0-rc2
18
19
20 Changes in synapse v0.27.0-rc2 (2018-03-19)
21 ===========================================
22
23 Pulls in v0.26.1
24
25 Bug fixes:
26
27 * Fix bug introduced in v0.27.0-rc1 that causes much increased memory usage in state cache (PR #3005)
28
29
30 Changes in synapse v0.26.1 (2018-03-15)
31 =======================================
32
33 Bug fixes:
34
35 * Fix bug where an invalid event caused the server to stop functioning correctly,
36   due to parsing and serializing bugs in the ujson library (PR #3008)
37
38
39 Changes in synapse v0.27.0-rc1 (2018-03-14)
40 ===========================================
41
42 The common case for running Synapse is not to run separate workers, but for those that do, be aware that synctl no longer starts the main synapse when using the ``-a`` option with workers. A new worker file should be added with ``worker_app: synapse.app.homeserver``.
43
44 This release also begins the process of renaming a number of the metrics
45 reported to prometheus. See `docs/metrics-howto.rst <docs/metrics-howto.rst#block-and-response-metrics-renamed-for-0-27-0>`_.
46 Note that the v0.28.0 release will remove the deprecated metric names.
47
48 Features:
49
50 * Add ability for ASes to override message send time (PR #2754)
51 * Add support for custom storage providers for media repository (PR #2867, #2777, #2783, #2789, #2791, #2804, #2812, #2814, #2857, #2868, #2767)
52 * Add purge API features, see `docs/admin_api/purge_history_api.rst <docs/admin_api/purge_history_api.rst>`_ for full details (PR #2858, #2867, #2882, #2946, #2962, #2943)
53 * Add support for whitelisting 3PIDs that users can register. (PR #2813)
54 * Add ``/room/{id}/event/{id}`` API (PR #2766)
55 * Add an admin API to get all the media in a room (PR #2818) Thanks to @turt2live!
56 * Add ``federation_domain_whitelist`` option (PR #2820, #2821)
57
58
59 Changes:
60
61 * Continue to factor out processing from the main process and into worker processes. See updated `docs/workers.rst <docs/workers.rst>`_ (PR #2892 - #2904, #2913, #2920 - #2926, #2947, #2847, #2854, #2872, #2873, #2874, #2928, #2929, #2934, #2856, #2976 - #2984, #2987 - #2989, #2991 - #2993, #2995, #2784)
62 * Ensure state cache is used when persisting events (PR #2864, #2871, #2802, #2835, #2836, #2841, #2842, #2849)
63 * Change the default config to bind on both IPv4 and IPv6 on all platforms (PR #2435) Thanks to @silkeh!
64 * No longer require a specific version of saml2 (PR #2695) Thanks to @okurz!
65 * Remove ``verbosity``/``log_file`` from generated config (PR #2755)
66 * Add and improve metrics and logging (PR #2770, #2778, #2785, #2786, #2787, #2793, #2794, #2795, #2809, #2810, #2833, #2834, #2844, #2965, #2927, #2975, #2790, #2796, #2838)
67 * When using synctl with workers, don't start the main synapse automatically (PR #2774)
68 * Minor performance improvements (PR #2773, #2792)
69 * Use a connection pool for non-federation outbound connections (PR #2817)
70 * Make it possible to run unit tests against postgres (PR #2829)
71 * Update pynacl dependency to 1.2.1 or higher (PR #2888) Thanks to @bachp!
72 * Remove ability for AS users to call /events and /sync (PR #2948)
73 * Use bcrypt.checkpw (PR #2949) Thanks to @krombel!
74
75 Bug fixes:
76
77 * Fix broken ``ldap_config`` config option (PR #2683) Thanks to @seckrv!
78 * Fix error message when user is not allowed to unban (PR #2761) Thanks to @turt2live!
79 * Fix publicised groups GET API (singular) over federation (PR #2772)
80 * Fix user directory when using ``user_directory_search_all_users`` config option (PR #2803, #2831)
81 * Fix error on ``/publicRooms`` when no rooms exist (PR #2827)
82 * Fix bug in quarantine_media (PR #2837)
83 * Fix url_previews when no Content-Type is returned from URL (PR #2845)
84 * Fix rare race in sync API when joining room (PR #2944)
85 * Fix slow event search, switch back from GIST to GIN indexes (PR #2769, #2848)
86
87
88 Changes in synapse v0.26.0 (2018-01-05)
89 =======================================
90
91 No changes since v0.26.0-rc1
92
93
94 Changes in synapse v0.26.0-rc1 (2017-12-13)
95 ===========================================
96
97 Features:
98
99 * Add ability for ASes to publicise groups for their users (PR #2686)
100 * Add all local users to the user_directory and optionally search them (PR
101 #2723)
102 * Add support for custom login types for validating users (PR #2729)
103
104
105 Changes:
106
107 * Update example Prometheus config to new format (PR #2648) Thanks to
108 @krombel!
109 * Rename redact_content option to include_content in Push API (PR #2650)
110 * Declare support for r0.3.0 (PR #2677)
111 * Improve upserts (PR #2684, #2688, #2689, #2713)
112 * Improve documentation of workers (PR #2700)
113 * Improve tracebacks on exceptions (PR #2705)
114 * Allow guest access to group APIs for reading (PR #2715)
115 * Support for posting content in federation_client script (PR #2716)
116 * Delete devices and pushers on logouts etc (PR #2722)
117
118
119 Bug fixes:
120
121 * Fix database port script (PR #2673)
122 * Fix internal server error on login with ldap_auth_provider (PR #2678) Thanks
123 to @jkolo!
124 * Fix error on sqlite 3.7 (PR #2697)
125 * Fix OPTIONS on preview_url (PR #2707)
126 * Fix error handling on dns lookup (PR #2711)
127 * Fix wrong avatars when inviting multiple users when creating room (PR #2717)
128 * Fix 500 when joining matrix-dev (PR #2719)
129
130
131 Changes in synapse v0.25.1 (2017-11-17)
132 =======================================
133
134 Bug fixes:
135
136 * Fix login with LDAP and other password provider modules (PR #2678). Thanks to
137 @jkolo!
138
139 Changes in synapse v0.25.0 (2017-11-15)
140 =======================================
141
142 Bug fixes:
143
144 * Fix port script (PR #2673)
145
146
147 Changes in synapse v0.25.0-rc1 (2017-11-14)
148 ===========================================
149
150 Features:
151
152 * Add is_public to groups table to allow for private groups (PR #2582)
153 * Add a route for determining who you are (PR #2668) Thanks to @turt2live!
154 * Add more features to the password providers (PR #2608, #2610, #2620, #2622,
155 #2623, #2624, #2626, #2628, #2629)
156 * Add a hook for custom rest endpoints (PR #2627)
157 * Add API to update group room visibility (PR #2651)
158
159
160 Changes:
161
162 * Ignore <noscript> tags when generating URL preview descriptions (PR #2576)
163 Thanks to @maximevaillancourt!
164 * Register some /unstable endpoints in /r0 as well (PR #2579) Thanks to
165 @krombel!
166 * Support /keys/upload on /r0 as well as /unstable (PR #2585)
167 * Front-end proxy: pass through auth header (PR #2586)
168 * Allow ASes to deactivate their own users (PR #2589)
169 * Remove refresh tokens (PR #2613)
170 * Automatically set default displayname on register (PR #2617)
171 * Log login requests (PR #2618)
172 * Always return `is_public` in the `/groups/:group_id/rooms` API (PR #2630)
173 * Avoid no-op media deletes (PR #2637) Thanks to @spantaleev!
174 * Fix various embarrassing typos around user_directory and add some doc. (PR
175 #2643)
176 * Return whether a user is an admin within a group (PR #2647)
177 * Namespace visibility options for groups (PR #2657)
178 * Downcase UserIDs on registration (PR #2662)
179 * Cache failures when fetching URL previews (PR #2669)
180
181
182 Bug fixes:
183
184 * Fix port script (PR #2577)
185 * Fix error when running synapse with no logfile (PR #2581)
186 * Fix UI auth when deleting devices (PR #2591)
187 * Fix typo when checking if user is invited to group (PR #2599)
188 * Fix the port script to drop NUL values in all tables (PR #2611)
189 * Fix appservices being backlogged and not receiving new events due to a bug in
190 notify_interested_services (PR #2631) Thanks to @xyzz!
191 * Fix updating rooms avatar/display name when modified by admin (PR #2636)
192 Thanks to @farialima!
193 * Fix bug in state group storage (PR #2649)
194 * Fix 500 on invalid utf-8 in request (PR #2663)
195
196
197 Changes in synapse v0.24.1 (2017-10-24)
198 =======================================
199
200 Bug fixes:
201
202 * Fix updating group profiles over federation (PR #2567)
203
204
0205 Changes in synapse v0.24.0 (2017-10-23)
1206 =======================================
2207
2929 you to make any refinements needed or merge it and make them ourselves. The
3030 changes will then land on master when we next do a release.
3131
32 We use Jenkins for continuous integration (http://matrix.org/jenkins), and
33 typically all pull requests get automatically tested by Jenkins: if your change breaks the build, Jenkins will yell about it in #matrix-dev:matrix.org so please lurk there and keep an eye open.
32 We use `Jenkins <http://matrix.org/jenkins>`_ and
33 `Travis <https://travis-ci.org/matrix-org/synapse>`_ for continuous
34 integration. All pull requests to synapse get automatically tested by Travis;
35 the Jenkins builds require an administrator to start them. If your change
36 breaks the build, this will be shown in github, so please keep an eye on the
37 pull request for feedback.
3438
3539 Code style
3640 ~~~~~~~~~~
114118 Conclusion
115119 ~~~~~~~~~~
116120
117 That's it! Matrix is a very open and collaborative project as you might expect given our obsession with open communication. If we're going to successfully matrix together all the fragmented communication technologies out there we are reliant on contributions and collaboration from the community to do so. So please get involved - and we hope you have as much fun hacking on Matrix as we do!
121 That's it! Matrix is a very open and collaborative project as you might expect given our obsession with open communication. If we're going to successfully matrix together all the fragmented communication technologies out there we are reliant on contributions and collaboration from the community to do so. So please get involved - and we hope you have as much fun hacking on Matrix as we do!
352352
353353 Fedora
354354 ------
355
356 Synapse is in the Fedora repositories as ``matrix-synapse``::
357
358 sudo dnf install matrix-synapse
355359
356360 Oleg Girko provides Fedora RPMs at
357361 https://obs.infoserver.lv/project/monitor/matrix-synapse
631635
632636 Troubleshooting
633637 ---------------
638
639 You can use the federation tester to check if your homeserver is all set:
640 ``https://matrix.org/federationtester/api/report?server_name=<your_server_name>``
641 If any of the attributes under "checks" is false, federation won't work.
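
For instance, a minimal Python sketch of such a check (the exact layout of the
report beyond the "checks" attribute is an assumption; adapt it to what the
tester actually returns)::

    import json
    from urllib.request import urlopen

    server_name = "example.com"  # replace with your server_name
    url = ("https://matrix.org/federationtester/api/report?server_name="
           + server_name)
    report = json.load(urlopen(url))

    # Federation won't work until every check passes.
    for name, ok in report.get("checks", {}).items():
        if not ok:
            print("Failed check: %s" % name)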
642
634643 The typical failure mode with federation is that when you try to join a room,
635644 it is rejected with "401: Unauthorized". Generally this means that other
636645 servers in the room couldn't access yours. (Joining a room over federation is a
822831 your loopback and RFC1918 IP addresses are blacklisted.
823832
824833 This also requires the optional lxml and netaddr python dependencies to be
825 installed.
834 installed. This in turn requires the libxml2 library to be available - on
835 Debian/Ubuntu this means ``apt-get install libxml2-dev``, or equivalent for
836 your OS.
826837
827838
828839 Password reset
882893
883894 PASSED (successes=143)
884895
896 Running the Integration Tests
897 =============================
898
899 Synapse is accompanied by `SyTest <https://github.com/matrix-org/sytest>`_,
900 a Matrix homeserver integration testing suite, which uses HTTP requests to
901 access the API as a Matrix client would. It is able to run Synapse directly from
902 the source tree, so installation of the server is not required.
903
904 Testing with SyTest is recommended for verifying that changes related to the
905 Client-Server API are functioning correctly. See the `installation instructions
906 <https://github.com/matrix-org/sytest#installing>`_ for details.
885907
886908 Building Internal API Documentation
887909 ===================================
44
55 http://prometheus.io/
66
7 Then add a new job to the main prometheus.conf file:
7 ### for Prometheus v1
8 Add a new job to the main prometheus.conf file:
89
910 job: {
1011 name: "synapse"
1415 }
1516 }
1617
18 ### for Prometheus v2
19 Add a new job to the main prometheus.yml file:
20
21 - job_name: "synapse"
22 metrics_path: "/_synapse/metrics"
23 # when endpoint uses https:
24 scheme: "https"
25
26 static_configs:
27 - targets: ['SERVER.LOCATION:PORT']
28
29 To use `synapse.rules` add
30
31 rule_files:
32 - "/PATH/TO/synapse-v2.rules"
33
1734 Metrics are disabled by default when running synapse; they must be enabled
1835 with the 'enable-metrics' option, either in the synapse config file or as a
1936 command-line option.
0 groups:
1 - name: synapse
2 rules:
3 - record: "synapse_federation_transaction_queue_pendingEdus:total"
4 expr: "sum(synapse_federation_transaction_queue_pendingEdus or absent(synapse_federation_transaction_queue_pendingEdus)*0)"
5 - record: "synapse_federation_transaction_queue_pendingPdus:total"
6 expr: "sum(synapse_federation_transaction_queue_pendingPdus or absent(synapse_federation_transaction_queue_pendingPdus)*0)"
7 - record: 'synapse_http_server_requests:method'
8 labels:
9 servlet: ""
10 expr: "sum(synapse_http_server_requests) by (method)"
11 - record: 'synapse_http_server_requests:servlet'
12 labels:
13 method: ""
14 expr: 'sum(synapse_http_server_requests) by (servlet)'
15
16 - record: 'synapse_http_server_requests:total'
17 labels:
18 servlet: ""
19 expr: 'sum(synapse_http_server_requests:by_method) by (servlet)'
20
21 - record: 'synapse_cache:hit_ratio_5m'
22 expr: 'rate(synapse_util_caches_cache:hits[5m]) / rate(synapse_util_caches_cache:total[5m])'
23 - record: 'synapse_cache:hit_ratio_30s'
24 expr: 'rate(synapse_util_caches_cache:hits[30s]) / rate(synapse_util_caches_cache:total[30s])'
25
26 - record: 'synapse_federation_client_sent'
27 labels:
28 type: "EDU"
29 expr: 'synapse_federation_client_sent_edus + 0'
30 - record: 'synapse_federation_client_sent'
31 labels:
32 type: "PDU"
33 expr: 'synapse_federation_client_sent_pdu_destinations:count + 0'
34 - record: 'synapse_federation_client_sent'
35 labels:
36 type: "Query"
37 expr: 'sum(synapse_federation_client_sent_queries) by (job)'
38
39 - record: 'synapse_federation_server_received'
40 labels:
41 type: "EDU"
42 expr: 'synapse_federation_server_received_edus + 0'
43 - record: 'synapse_federation_server_received'
44 labels:
45 type: "PDU"
46 expr: 'synapse_federation_server_received_pdus + 0'
47 - record: 'synapse_federation_server_received'
48 labels:
49 type: "Query"
50 expr: 'sum(synapse_federation_server_received_queries) by (job)'
51
52 - record: 'synapse_federation_transaction_queue_pending'
53 labels:
54 type: "EDU"
55 expr: 'synapse_federation_transaction_queue_pending_edus + 0'
56 - record: 'synapse_federation_transaction_queue_pending'
57 labels:
58 type: "PDU"
59 expr: 'synapse_federation_transaction_queue_pending_pdus + 0'
+0
-21
contrib/prometheus/synapse.rules
0 synapse_federation_transaction_queue_pendingEdus:total = sum(synapse_federation_transaction_queue_pendingEdus or absent(synapse_federation_transaction_queue_pendingEdus)*0)
1 synapse_federation_transaction_queue_pendingPdus:total = sum(synapse_federation_transaction_queue_pendingPdus or absent(synapse_federation_transaction_queue_pendingPdus)*0)
2
3 synapse_http_server_requests:method{servlet=""} = sum(synapse_http_server_requests) by (method)
4 synapse_http_server_requests:servlet{method=""} = sum(synapse_http_server_requests) by (servlet)
5
6 synapse_http_server_requests:total{servlet=""} = sum(synapse_http_server_requests:by_method) by (servlet)
7
8 synapse_cache:hit_ratio_5m = rate(synapse_util_caches_cache:hits[5m]) / rate(synapse_util_caches_cache:total[5m])
9 synapse_cache:hit_ratio_30s = rate(synapse_util_caches_cache:hits[30s]) / rate(synapse_util_caches_cache:total[30s])
10
11 synapse_federation_client_sent{type="EDU"} = synapse_federation_client_sent_edus + 0
12 synapse_federation_client_sent{type="PDU"} = synapse_federation_client_sent_pdu_destinations:count + 0
13 synapse_federation_client_sent{type="Query"} = sum(synapse_federation_client_sent_queries) by (job)
14
15 synapse_federation_server_received{type="EDU"} = synapse_federation_server_received_edus + 0
16 synapse_federation_server_received{type="PDU"} = synapse_federation_server_received_pdus + 0
17 synapse_federation_server_received{type="Query"} = sum(synapse_federation_server_received_queries) by (job)
18
19 synapse_federation_transaction_queue_pending{type="EDU"} = synapse_federation_transaction_queue_pending_edus + 0
20 synapse_federation_transaction_queue_pending{type="PDU"} = synapse_federation_transaction_queue_pending_pdus + 0
0 # List all media in a room
1
2 This API gets a list of known media in a room.
3
4 The API is:
5 ```
6 GET /_matrix/client/r0/admin/room/<room_id>/media
7 ```
8 including an `access_token` of a server admin.
9
10 It returns a JSON body like the following:
11 ```
12 {
13 "local": [
14 "mxc://localhost/xwvutsrqponmlkjihgfedcba",
15 "mxc://localhost/abcdefghijklmnopqrstuvwx"
16 ],
17 "remote": [
18 "mxc://matrix.org/xwvutsrqponmlkjihgfedcba",
19 "mxc://matrix.org/abcdefghijklmnopqrstuvwx"
20 ]
21 }
22 ```
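
For example, a minimal Python sketch of a client for this endpoint (the
`requests` dependency, the example identifiers and passing the token as a
query parameter are assumptions):

```python
import requests

homeserver = "https://localhost:8448"  # illustrative
room_id = "!roomid:localhost"          # illustrative
access_token = "ADMIN_TOKEN"           # a server admin's token

resp = requests.get(
    "%s/_matrix/client/r0/admin/room/%s/media" % (homeserver, room_id),
    params={"access_token": access_token},
)
resp.raise_for_status()
media = resp.json()

# The response lists the known local and remote mxc:// URIs in the room.
print("local:", media["local"])
print("remote:", media["remote"])
```
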
33 The purge history API allows server admins to purge historic events from their
44 database, reclaiming disk space.
55
6 **NB!** This will not delete local events (locally sent messages content etc) from the database, but will remove lots of the metadata about them and does dramatically reduce the on disk space usage
7
86 Depending on the amount of history being purged a call to the API may take
97 several minutes or longer. During this period users will not be able to
108 paginate further back in the room from the point being purged from.
119
12 The API is simply:
10 The API is:
1311
14 ``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>``
12 ``POST /_matrix/client/r0/admin/purge_history/<room_id>[/<event_id>]``
1513
1614 including an ``access_token`` of a server admin.
15
16 By default, events sent by local users are not deleted, as they may represent
17 the only copies of this content in existence. (Events sent by remote users are
18 deleted.)
19
20 Room state data (such as joins, leaves, topic) is always preserved.
21
22 To delete local message events as well, set ``delete_local_events`` in the body:
23
24 .. code:: json
25
26 {
27 "delete_local_events": true
28 }
29
30 The caller must specify the point in the room to purge up to. This can be
31 specified by including an event_id in the URI, or by setting a
32 ``purge_up_to_event_id`` or ``purge_up_to_ts`` in the request body. If an event
33 id is given, that event (and others at the same graph depth) will be retained.
34 If ``purge_up_to_ts`` is given, it should be a timestamp since the unix epoch,
35 in milliseconds.
36
37 The API starts the purge running, and returns immediately with a JSON body with
38 a purge id:
39
40 .. code:: json
41
42 {
43 "purge_id": "<opaque id>"
44 }
45
46 Purge status query
47 ------------------
48
49 It is possible to poll for updates on recent purges with a second API:
50
51 ``GET /_matrix/client/r0/admin/purge_history_status/<purge_id>``
52
53 (again, with a suitable ``access_token``). This API returns a JSON body like
54 the following:
55
56 .. code:: json
57
58 {
59 "status": "active"
60 }
61
62 The status will be one of ``active``, ``complete``, or ``failed``.
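
Putting the above together, a minimal Python sketch of a purge-and-poll client
(the ``requests`` dependency, the example identifiers and passing the admin
token as a query parameter are assumptions):

.. code:: python

    import time

    import requests

    homeserver = "https://localhost:8448"  # illustrative
    room_id = "!roomid:localhost"          # illustrative
    admin_token = "ADMIN_TOKEN"

    # Start a purge of history up to a given timestamp (ms since the epoch),
    # keeping events sent by local users (the default).
    resp = requests.post(
        "%s/_matrix/client/r0/admin/purge_history/%s" % (homeserver, room_id),
        params={"access_token": admin_token},
        json={"purge_up_to_ts": 1514764800000},
    )
    resp.raise_for_status()
    purge_id = resp.json()["purge_id"]

    # Poll the status API until the purge is no longer active.
    status = "active"
    while status == "active":
        time.sleep(5)
        status = requests.get(
            "%s/_matrix/client/r0/admin/purge_history_status/%s"
            % (homeserver, purge_id),
            params={"access_token": admin_token},
        ).json()["status"]

    print("purge finished with status: %s" % status)
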
0 Basically, PEP8
0 - Everything should comply with PEP8. Code should pass
1 ``pep8 --max-line-length=100`` without any warnings.
12
2 - NEVER tabs. 4 spaces to indent.
3 - Max line width: 79 chars (with flexibility to overflow by a "few chars" if
3 - **Indenting**:
4
5 - NEVER tabs. 4 spaces to indent.
6
7 - follow PEP8; either hanging indent or multiline-visual indent depending
8 on the size and shape of the arguments and what makes more sense to the
9 author. In other words, both this::
10
11 print("I am a fish %s" % "moo")
12
13 and this::
14
15 print("I am a fish %s" %
16 "moo")
17
18 and this::
19
20 print(
21 "I am a fish %s" %
22 "moo",
23 )
24
25 ...are valid, although given each one takes up 2x more vertical space than
26 the previous, it's up to the author's discretion as to which layout makes
27 most sense for their function invocation. (e.g. if they want to add
28 comments per-argument, or put expressions in the arguments, or group
29 related arguments together, or want to deliberately extend or preserve
30 vertical/horizontal space)
31
32 - **Line length**:
33
34 Max line length is 79 chars (with flexibility to overflow by a "few chars" if
435 the overflowing content is not semantically significant and avoids an
536 explosion of vertical whitespace).
6 - Use camel case for class and type names
7 - Use underscores for functions and variables.
8 - Use double quotes.
9 - Use parentheses instead of '\\' for line continuation where ever possible
10 (which is pretty much everywhere)
11 - There should be max a single new line between:
37
38 Use parentheses instead of ``\`` for line continuation wherever possible
39 (which is pretty much everywhere).
40
41 - **Naming**:
42
43 - Use camel case for class and type names
44 - Use underscores for functions and variables.
45
46 - Use double quotes ``"foo"`` rather than single quotes ``'foo'``.
47
48 - **Blank lines**:
49
50 - There should be at most a single blank line between:
51
1252 - statements
1353 - functions in a class
14 - There should be two new lines between:
54
55 - There should be two blank lines between:
56
1557 - definitions in a module (e.g., between different classes)
16 - There should be spaces where spaces should be and not where there shouldn't be:
17 - a single space after a comma
18 - a single space before and after for '=' when used as assignment
19 - no spaces before and after for '=' for default values and keyword arguments.
20 - Indenting must follow PEP8; either hanging indent or multiline-visual indent
21 depending on the size and shape of the arguments and what makes more sense to
22 the author. In other words, both this::
2358
24 print("I am a fish %s" % "moo")
59 - **Whitespace**:
2560
26 and this::
61 There should be spaces where spaces should be and not where there shouldn't
62 be:
2763
28 print("I am a fish %s" %
29 "moo")
64 - a single space after a comma
65 - a single space before and after for '=' when used as assignment
66 - no spaces before and after for '=' for default values and keyword arguments.
3067
31 and this::
68 - **Comments**: should follow the `google code style
69 <http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_.
70 This is so that we can generate documentation with `sphinx
71 <http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the
72 `examples
73 <http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_
74 in the sphinx documentation.
3275
33 print(
34 "I am a fish %s" %
35 "moo"
36 )
76 - **Imports**:
3777
38 ...are valid, although given each one takes up 2x more vertical space than
39 the previous, it's up to the author's discretion as to which layout makes most
40 sense for their function invocation. (e.g. if they want to add comments
41 per-argument, or put expressions in the arguments, or group related arguments
42 together, or want to deliberately extend or preserve vertical/horizontal
43 space)
78 - Prefer to import classes and functions than packages or modules.
4479
45 Comments should follow the `google code style <http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_.
46 This is so that we can generate documentation with
47 `sphinx <http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the
48 `examples <http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_
49 in the sphinx documentation.
80 Example::
5081
51 Code should pass pep8 --max-line-length=100 without any warnings.
82 from synapse.types import UserID
83 ...
84 user_id = UserID(local, server)
85
86 is preferred over::
87
88 from synapse import types
89 ...
90 user_id = types.UserID(local, server)
91
92 (or any other variant).
93
94 This goes against the advice in the Google style guide, but it means that
95 errors in the name are caught early (at import time).
96
97 - Multiple imports from the same package can be combined onto one line::
98
99 from synapse.types import GroupID, RoomID, UserID
100
101 An effort should be made to keep the individual imports in alphabetical
102 order.
103
104 If the list becomes long, wrap it with parentheses and split it over
105 multiple lines.
106
107 - As per `PEP-8 <https://www.python.org/dev/peps/pep-0008/#imports>`_,
108 imports should be grouped in the following order, with a blank line between
109 each group, as illustrated in the sketch below:
110
111 1. standard library imports
112 2. related third party imports
113 3. local application/library specific imports
114
115 - Imports within each group should be sorted alphabetically by module name.
116
117 - Avoid wildcard imports (``from synapse.types import *``) and relative
118 imports (``from .types import UserID``).
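
For instance, the top of a module might be laid out like this (a minimal
sketch; the module names are purely illustrative)::

    # standard library
    import logging
    import os

    # third party
    from twisted.internet import defer

    # synapse
    from synapse.types import GroupID, RoomID, UserID

    logger = logging.getLogger(__name__)
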
278278 that might be fixed by setting a different logcontext via a ``with
279279 LoggingContext(...)`` in ``background_operation``).
280280
281 The second option is to use ``logcontext.preserve_fn``, which wraps a function
282 so that it doesn't reset the logcontext even when it returns an incomplete
283 deferred, and adds a callback to the returned deferred to reset the
281 The second option is to use ``logcontext.run_in_background``, which wraps a
282 function so that it doesn't reset the logcontext even when it returns an
283 incomplete deferred, and adds a callback to the returned deferred to reset the
284284 logcontext. In other words, it turns a function that follows the Synapse rules
285285 about logcontexts and Deferreds into one which behaves more like an external
286286 function — the opposite operation to that described in the previous section.
292292 def do_request_handling():
293293 yield foreground_operation()
294294
295 logcontext.preserve_fn(background_operation)()
295 logcontext.run_in_background(background_operation)
296296
297297 # this will now be logged against the request context
298298 logger.debug("Request handling complete")
299
300 XXX: I think ``preserve_context_over_fn`` is supposed to do the first option,
301 but the fact that it does ``preserve_context_over_deferred`` on its results
302 means that its use is fraught with difficulty.
303299
304300 Passing synapse deferreds into third-party functions
305301 ----------------------------------------------------
1515 metrics_port: 9092
1616
1717 Also ensure that ``enable_metrics`` is set to ``True``.
18
18
1919 Restart synapse.
2020
2121 3. Add a prometheus target for synapse.
2727 static_configs:
2828 - targets: ["my.server.here:9092"]
2929
30 If your prometheus is older than 1.5.2, you will need to replace
30 If your prometheus is older than 1.5.2, you will need to replace
3131 ``static_configs`` in the above with ``target_groups``.
32
32
3333 Restart prometheus.
34
35
36 Block and response metrics renamed for 0.27.0
37 ---------------------------------------------
38
39 Synapse 0.27.0 begins the process of rationalising the duplicate ``*:count``
40 metrics reported for the resource tracking for code blocks and HTTP requests.
41
42 At the same time, the corresponding ``*:total`` metrics are being renamed, as
43 the ``:total`` suffix no longer makes sense in the absence of a corresponding
44 ``:count`` metric.
45
46 To enable a graceful migration path, this release just adds new names for the
47 metrics being renamed. A future release will remove the old ones.
48
49 The following table shows the new metrics, and the old metrics which they are
50 replacing.
51
52 ==================================================== ===================================================
53 New name Old name
54 ==================================================== ===================================================
55 synapse_util_metrics_block_count synapse_util_metrics_block_timer:count
56 synapse_util_metrics_block_count synapse_util_metrics_block_ru_utime:count
57 synapse_util_metrics_block_count synapse_util_metrics_block_ru_stime:count
58 synapse_util_metrics_block_count synapse_util_metrics_block_db_txn_count:count
59 synapse_util_metrics_block_count synapse_util_metrics_block_db_txn_duration:count
60
61 synapse_util_metrics_block_time_seconds synapse_util_metrics_block_timer:total
62 synapse_util_metrics_block_ru_utime_seconds synapse_util_metrics_block_ru_utime:total
63 synapse_util_metrics_block_ru_stime_seconds synapse_util_metrics_block_ru_stime:total
64 synapse_util_metrics_block_db_txn_count synapse_util_metrics_block_db_txn_count:total
65 synapse_util_metrics_block_db_txn_duration_seconds synapse_util_metrics_block_db_txn_duration:total
66
67 synapse_http_server_response_count synapse_http_server_requests
68 synapse_http_server_response_count synapse_http_server_response_time:count
69 synapse_http_server_response_count synapse_http_server_response_ru_utime:count
70 synapse_http_server_response_count synapse_http_server_response_ru_stime:count
71 synapse_http_server_response_count synapse_http_server_response_db_txn_count:count
72 synapse_http_server_response_count synapse_http_server_response_db_txn_duration:count
73
74 synapse_http_server_response_time_seconds synapse_http_server_response_time:total
75 synapse_http_server_response_ru_utime_seconds synapse_http_server_response_ru_utime:total
76 synapse_http_server_response_ru_stime_seconds synapse_http_server_response_ru_stime:total
77 synapse_http_server_response_db_txn_count synapse_http_server_response_db_txn_count:total
78 synapse_http_server_response_db_txn_duration_seconds synapse_http_server_response_db_txn_duration:total
79 ==================================================== ===================================================
80
3481
3582 Standard Metric Names
3683 ---------------------
4188
4289 ================================== =============================
4390 New name Old name
44 ---------------------------------- -----------------------------
91 ================================== =============================
4592 process_cpu_user_seconds_total process_resource_utime / 1000
4693 process_cpu_system_seconds_total process_resource_stime / 1000
4794 process_open_fds (no 'type' label) process_fds
5198
5299 =========================== ======================
53100 New name Old name
54 --------------------------- ----------------------
55 python_gc_time reactor_gc_time
101 =========================== ======================
102 python_gc_time reactor_gc_time
56103 python_gc_unreachable_total reactor_gc_unreachable
57104 python_gc_counts reactor_gc_counts
58105 =========================== ======================
61108
62109 ==================================== =====================
63110 New name Old name
64 ------------------------------------ ---------------------
111 ==================================== =====================
65112 python_twisted_reactor_pending_calls reactor_pending_calls
66113 python_twisted_reactor_tick_time reactor_tick_time
67114 ==================================== =====================
0 Password auth provider modules
1 ==============================
2
3 Password auth providers offer a way for server administrators to integrate
4 their Synapse installation with an existing authentication system.
5
6 A password auth provider is a Python class which is dynamically loaded into
7 Synapse, and provides a number of methods by which it can integrate with the
8 authentication system.
9
10 This document serves as a reference for those looking to implement their own
11 password auth providers.
12
13 Required methods
14 ----------------
15
16 Password auth provider classes must provide the following methods:
17
18 *class* ``SomeProvider.parse_config``\(*config*)
19
20 This method is passed the ``config`` object for this module from the
21 homeserver configuration file.
22
23 It should perform any appropriate sanity checks on the provided
24 configuration, and return an object which is then passed into ``__init__``.
25
26 *class* ``SomeProvider``\(*config*, *account_handler*)
27
28 The constructor is passed the config object returned by ``parse_config``,
29 and a ``synapse.module_api.ModuleApi`` object which allows the
30 password provider to check if accounts exist and/or create new ones.
31
32 Optional methods
33 ----------------
34
35 Password auth provider classes may optionally provide the following methods.
36
37 *class* ``SomeProvider.get_db_schema_files``\()
38
39 This method, if implemented, should return an Iterable of ``(name,
40 stream)`` pairs of database schema files. Each file is applied in turn at
41 initialisation, and a record is then made in the database so that it is
42 not re-applied on the next start.
43
44 ``someprovider.get_supported_login_types``\()
45
46 This method, if implemented, should return a ``dict`` mapping from a login
47 type identifier (such as ``m.login.password``) to an iterable giving the
48 fields which must be provided by the user in the submission to the
49 ``/login`` api. These fields are passed in the ``login_dict`` dictionary
50 to ``check_auth``.
51
52 For example, if a password auth provider wants to implement a custom login
53 type of ``com.example.custom_login``, where the client is expected to pass
54 the fields ``secret1`` and ``secret2``, the provider should implement this
55 method and return the following dict::
56
57 {"com.example.custom_login": ("secret1", "secret2")}
58
59 ``someprovider.check_auth``\(*username*, *login_type*, *login_dict*)
60
61 This method is the one that does the real work. If implemented, it will be
62 called for each login attempt where the login type matches one of the keys
63 returned by ``get_supported_login_types``.
64
65 It is passed the (possibly UNqualified) ``user`` provided by the client,
66 the login type, and a dictionary of login secrets passed by the client.
67
68 The method should return a Twisted ``Deferred`` object, which resolves to
69 the canonical ``@localpart:domain`` user id if authentication is successful,
70 and ``None`` if not.
71
72 Alternatively, the ``Deferred`` can resolve to a ``(str, func)`` tuple, in
73 which case the second field is a callback which will be called with the
74 result from the ``/login`` call (including ``access_token``, ``device_id``,
75 etc.)
76
77 ``someprovider.check_password``\(*user_id*, *password*)
78
79 This method provides a simpler interface than ``get_supported_login_types``
80 and ``check_auth`` for password auth providers that just want to provide a
81 mechanism for validating ``m.login.password`` logins.
82
83 If implemented, it will be called to check logins with an
84 ``m.login.password`` login type. It is passed a qualified
85 ``@localpart:domain`` user id, and the password provided by the user.
86
87 The method should return a Twisted ``Deferred`` object, which resolves to
88 ``True`` if authentication is successful, and ``False`` if not.
89
90 ``someprovider.on_logged_out``\(*user_id*, *device_id*, *access_token*)
91
92 This method, if implemented, is called when a user logs out. It is passed
93 the qualified user ID, the ID of the deactivated device (if any: access
94 tokens are occasionally created without an associated device ID), and the
95 (now deactivated) access token.
96
97 It may return a Twisted ``Deferred`` object; the logout request will wait
98 for the deferred to complete but the result is ignored.
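
As an illustration, here is a minimal sketch of a provider which implements
only the simple ``check_password`` interface described above (the config
format and the credential check itself are purely illustrative)::

    from twisted.internet import defer


    class ExampleAuthProvider(object):
        @staticmethod
        def parse_config(config):
            # Sanity-check the module config from the homeserver config file;
            # whatever is returned here is passed into __init__ below.
            if not isinstance(config, dict):
                raise Exception("expected a mapping of user id to password")
            return config

        def __init__(self, config, account_handler):
            # account_handler is the synapse.module_api.ModuleApi instance.
            self.account_handler = account_handler
            self.users = config

        def check_password(self, user_id, password):
            # user_id is the fully-qualified @localpart:domain id.
            ok = bool(password) and self.users.get(user_id) == password
            return defer.succeed(ok)
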
0 URL Previews
1 ============
2
3 Design notes on a URL previewing service for Matrix:
4
5 Options are:
6
7 1. Have an AS which listens for URLs, downloads them, and inserts an event that describes their metadata.
8 * Pros:
9 * Decouples the implementation entirely from Synapse.
10 * Uses existing Matrix events & content repo to store the metadata.
11 * Cons:
12 * Which AS should provide this service for a room, and why should you trust it?
13 * Doesn't work well with E2E; you'd have to cut the AS into every room
14 * the AS would end up subscribing to every room anyway.
15
16 2. Have a generic preview API (nothing to do with Matrix) that provides a previewing service:
17 * Pros:
18 * Simple and flexible; can be used by any clients at any point
19 * Cons:
20 * If each HS provides one of these independently, all the HSes in a room may needlessly DoS the target URI
21 * We need somewhere to store the URL metadata rather than just using Matrix itself
22 * We can't piggyback on matrix to distribute the metadata between HSes.
23
24 3. Make the synapse of the sending user responsible for spidering the URL and inserting an event asynchronously which describes the metadata.
25 * Pros:
26 * Works transparently for all clients
27 * Piggy-backs nicely on using Matrix for distributing the metadata.
28 * No confusion as to which AS
29 * Cons:
30 * Doesn't work with E2E
31 * We might want to decouple the implementation of the spider from the HS, given spider behaviour can be quite complicated and evolve much more rapidly than the HS. It's more like a bot than a core part of the server.
32
33 4. Make the sending client use the preview API and insert the event itself when successful.
34 * Pros:
35 * Works well with E2E
36 * No custom server functionality
37 * Lets the client customise the preview that they send (like on FB)
38 * Cons:
39 * Entirely specific to the sending client, whereas it'd be nice if /any/ URL was correctly previewed if clients support it.
40
41 5. Have the option of specifying a shared (centralised) previewing service used by a room, to avoid all the different HSes in the room DoSing the target.
42
43 Best solution is probably a combination of both 2 and 4.
44 * Sending clients do their best to create and send a preview at the point of sending the message, perhaps delaying the message until the preview is computed? (This also lets the user validate the preview before sending)
45 * Receiving clients have the option of going and creating their own preview if one doesn't arrive soon enough (or if the original sender didn't create one)
46
47 This is a bit magical though in that the preview could come from two entirely different sources - the sending HS or your local one. However, this can always be exposed to users: "Generate your own URL previews if none are available?"
48
49 This is tantamount also to senders calculating their own thumbnails for sending in advance of the main content - we are trusting the sender not to lie about the content in the thumbnail. Whereas currently thumbnails are calculated by the receiving homeserver to avoid this attack.
50
51 However, this kind of phishing attack does exist whether we let senders pick their thumbnails or not, in that a malicious sender can send normal text messages around the attachment claiming it to be legitimate. We could rely on (future) reputation/abuse management to punish users who phish (be it with bogus metadata or bogus descriptions). Bogus metadata is particularly bad though, especially if it's avoidable.
52
53 As a first cut, let's do #2 and have the receiver hit the API to calculate its own previews (as it does currently for image thumbnails). We can then extend/optimise this to option 4 as a special extra if needed.
54
55 API
56 ---
57
58 ```
59 GET /_matrix/media/r0/preview_url?url=http://wherever.com
60 200 OK
61 {
62 "og:type" : "article"
63 "og:url" : "https://twitter.com/matrixdotorg/status/684074366691356672"
64 "og:title" : "Matrix on Twitter"
65 "og:image" : "https://pbs.twimg.com/profile_images/500400952029888512/yI0qtFi7_400x400.png"
66 "og:description" : "“Synapse 0.12 is out! Lots of polishing, performance &amp;amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP”"
67 "og:site_name" : "Twitter"
68 }
69 ```
70
71 * Downloads the URL
72 * If HTML, just stores it in RAM and parses it for OG meta tags
73 * Download any media OG meta tags to the media repo, and refer to them in the OG via mxc:// URIs.
74 * If a media filetype we know we can thumbnail: store it on disk, and hand it to the thumbnailer. Generate OG meta tags from the thumbnailer contents.
75 * Otherwise, don't bother downloading further.
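
For reference, a minimal Python sketch of a client for the API above (the
`requests` dependency is an assumption, as is passing an `access_token`; the
endpoint may well require authentication):

```python
import requests

homeserver = "https://localhost:8448"  # illustrative

resp = requests.get(
    "%s/_matrix/media/r0/preview_url" % homeserver,
    params={
        "url": "http://wherever.com",
        "access_token": "CLIENT_TOKEN",  # assumption: auth is required
    },
)
og = resp.json()

# OpenGraph metadata extracted by the server; any og:image is rewritten
# to an mxc:// URI in the media repo.
print(og.get("og:title"), og.get("og:image"))
```
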
+0
-74
docs/url_previews.rst
0 URL Previews
1 ============
2
3 Design notes on a URL previewing service for Matrix:
4
5 Options are:
6
7 1. Have an AS which listens for URLs, downloads them, and inserts an event that describes their metadata.
8 * Pros:
9 * Decouples the implementation entirely from Synapse.
10 * Uses existing Matrix events & content repo to store the metadata.
11 * Cons:
12 * Which AS should provide this service for a room, and why should you trust it?
13 * Doesn't work well with E2E; you'd have to cut the AS into every room
14 * the AS would end up subscribing to every room anyway.
15
16 2. Have a generic preview API (nothing to do with Matrix) that provides a previewing service:
17 * Pros:
18 * Simple and flexible; can be used by any clients at any point
19 * Cons:
20 * If each HS provides one of these independently, all the HSes in a room may needlessly DoS the target URI
21 * We need somewhere to store the URL metadata rather than just using Matrix itself
22 * We can't piggyback on matrix to distribute the metadata between HSes.
23
24 3. Make the synapse of the sending user responsible for spidering the URL and inserting an event asynchronously which describes the metadata.
25 * Pros:
26 * Works transparently for all clients
27 * Piggy-backs nicely on using Matrix for distributing the metadata.
28 * No confusion as to which AS
29 * Cons:
30 * Doesn't work with E2E
31 * We might want to decouple the implementation of the spider from the HS, given spider behaviour can be quite complicated and evolve much more rapidly than the HS. It's more like a bot than a core part of the server.
32
33 4. Make the sending client use the preview API and insert the event itself when successful.
34 * Pros:
35 * Works well with E2E
36 * No custom server functionality
37 * Lets the client customise the preview that they send (like on FB)
38 * Cons:
39 * Entirely specific to the sending client, whereas it'd be nice if /any/ URL was correctly previewed if clients support it.
40
41 5. Have the option of specifying a shared (centralised) previewing service used by a room, to avoid all the different HSes in the room DoSing the target.
42
43 Best solution is probably a combination of both 2 and 4.
44 * Sending clients do their best to create and send a preview at the point of sending the message, perhaps delaying the message until the preview is computed? (This also lets the user validate the preview before sending)
45 * Receiving clients have the option of going and creating their own preview if one doesn't arrive soon enough (or if the original sender didn't create one)
46
47 This is a bit magical though in that the preview could come from two entirely different sources - the sending HS or your local one. However, this can always be exposed to users: "Generate your own URL previews if none are available?"
48
49 This is tantamount also to senders calculating their own thumbnails for sending in advance of the main content - we are trusting the sender not to lie about the content in the thumbnail. Whereas currently thumbnails are calculated by the receiving homeserver to avoid this attack.
50
51 However, this kind of phishing attack does exist whether we let senders pick their thumbnails or not, in that a malicious sender can send normal text messages around the attachment claiming it to be legitimate. We could rely on (future) reputation/abuse management to punish users who phish (be it with bogus metadata or bogus descriptions). Bogus metadata is particularly bad though, especially if it's avoidable.
52
53 As a first cut, let's do #2 and have the receiver hit the API to calculate its own previews (as it does currently for image thumbnails). We can then extend/optimise this to option 4 as a special extra if needed.
54
55 API
56 ---
57
58 GET /_matrix/media/r0/preview_url?url=http://wherever.com
59 200 OK
60 {
61 "og:type" : "article"
62 "og:url" : "https://twitter.com/matrixdotorg/status/684074366691356672"
63 "og:title" : "Matrix on Twitter"
64 "og:image" : "https://pbs.twimg.com/profile_images/500400952029888512/yI0qtFi7_400x400.png"
65 "og:description" : "“Synapse 0.12 is out! Lots of polishing, performance &amp;amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP”"
66 "og:site_name" : "Twitter"
67 }
68
69 * Downloads the URL
70 * If HTML, just stores it in RAM and parses it for OG meta tags
71 * Download any media OG meta tags to the media repo, and refer to them in the OG via mxc:// URIs.
72 * If a media filetype we know we can thumbnail: store it on disk, and hand it to the thumbnailer. Generate OG meta tags from the thumbnailer contents.
73 * Otherwise, don't bother downloading further.
0 User Directory API Implementation
1 =================================
2
3 The user directory is currently maintained based on the 'visible' users
4 on this particular server - i.e. ones which your account shares a room with, or
5 who are present in a publicly viewable room present on the server.
6
7 The directory info is stored in various tables, which can (typically after
8 DB corruption) get stale or out of sync. If this happens, for now the
9 quickest solution to fix it is:
10
11 ```
12 UPDATE user_directory_stream_pos SET stream_id = NULL;
13 ```
14
15 and restart the synapse, which should then start a background task to
16 flush the current tables and regenerate the directory.
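
If you prefer to run the fix from a script, a minimal sketch (the `psycopg2`
dependency and the connection details are assumptions; this only applies to
postgres-backed deployments, and synapse still needs a restart afterwards):

```python
import psycopg2

# Illustrative DSN; point this at your synapse database.
conn = psycopg2.connect("dbname=synapse user=synapse_user")
with conn, conn.cursor() as cur:
    # Clearing the stream position makes synapse rebuild the directory
    # in a background task on its next restart.
    cur.execute("UPDATE user_directory_stream_pos SET stream_id = NULL;")
conn.close()
```
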
00 Scaling synapse via workers
1 ---------------------------
1 ===========================
22
33 Synapse has experimental support for splitting out functionality into
44 multiple separate python processes, helping greatly with scalability. These
55 processes are called 'workers', and are (eventually) intended to scale
66 horizontally independently.
77
8 All of the below is highly experimental and subject to change as Synapse evolves,
9 but we document it here to help folks who need highly scalable Synapses similar
10 to the one running matrix.org!
11
812 All processes continue to share the same database instance, and as such, workers
913 only work with postgres based synapse deployments (sharing a single sqlite
1014 across multiple processes is a recipe for disaster, plus you should be using
1519 database replication; feeding a stream of relevant data to the workers so they
1620 can be kept in sync with the main synapse process and database state.
1721
18 To enable workers, you need to add a replication listener to the master synapse, e.g.::
22 Configuration
23 -------------
24
25 To make effective use of the workers, you will need to configure an HTTP
26 reverse-proxy such as nginx or haproxy, which will direct incoming requests to
27 the correct worker, or to the main synapse instance. Note that this includes
28 requests made to the federation port. The caveats regarding running a
29 reverse-proxy on the federation port still apply (see
30 https://github.com/matrix-org/synapse/blob/master/README.rst#reverse-proxying-the-federation-port).
31
32 To enable workers, you need to add two replication listeners to the master
33 synapse, e.g.::
1934
2035 listeners:
36 # The TCP replication port
2137 - port: 9092
2238 bind_address: '127.0.0.1'
2339 type: replication
24
25 Under **no circumstances** should this replication API listener be exposed to the
26 public internet; it currently implements no authentication whatsoever and is
40 # The HTTP replication port
41 - port: 9093
42 bind_address: '127.0.0.1'
43 type: http
44 resources:
45 - names: [replication]
46
47 Under **no circumstances** should these replication API listeners be exposed to
48 the public internet; it currently implements no authentication whatsoever and is
2749 unencrypted.
2850
29 You then create a set of configs for the various worker processes. These should be
30 worker configuration files should be stored in a dedicated subdirectory, to allow
31 synctl to manipulate them.
32
33 The current available worker applications are:
34 * synapse.app.pusher - handles sending push notifications to sygnal and email
35 * synapse.app.synchrotron - handles /sync endpoints. can scales horizontally through multiple instances.
36 * synapse.app.appservice - handles output traffic to Application Services
37 * synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API)
38 * synapse.app.media_repository - handles the media repository.
39 * synapse.app.client_reader - handles client API endpoints like /publicRooms
51 (Roughly, the TCP port is used for streaming data from the master to the
52 workers, and the HTTP port for the workers to send data to the main
53 synapse process.)
54
55 You then create a set of configs for the various worker processes. These
56 should be worker configuration files, and should be stored in a dedicated
57 subdirectory, to allow synctl to manipulate them.
4058
4159 Each worker configuration file inherits the configuration of the main homeserver
4260 configuration file. You can then override configuration specific to that worker,
4361 e.g. the HTTP listener that it provides (if any); logging configuration; etc.
4462 You should minimise the number of overrides though to maintain a usable config.
4563
46 You must specify the type of worker application (worker_app) and the replication
47 endpoint that it's talking to on the main synapse process (worker_replication_host
48 and worker_replication_port).
64 You must specify the type of worker application (``worker_app``). The currently
65 available worker applications are listed below. You must also specify the
66 replication endpoints that it's talking to on the main synapse process.
67 ``worker_replication_host`` should specify the host of the main synapse,
68 ``worker_replication_port`` should point to the TCP replication listener port and
69 ``worker_replication_http_port`` should point to the HTTP replication port.
70
71 Currently, only the ``event_creator`` worker requires specifying
72 ``worker_replication_http_port``.
4973
5074 For instance::
5175
5478 # The replication listener on the synapse to talk to.
5579 worker_replication_host: 127.0.0.1
5680 worker_replication_port: 9092
81 worker_replication_http_port: 9093
5782
5883 worker_listeners:
5984 - type: http
6792 worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml
6893
6994 ...is a full configuration for a synchrotron worker instance, which will expose a
70 plain HTTP /sync endpoint on port 8083 separately from the /sync endpoint provided
95 plain HTTP ``/sync`` endpoint on port 8083 separately from the ``/sync`` endpoint provided
7196 by the main synapse.
7297
73 Obviously you should configure your loadbalancer to route the /sync endpoint to
74 the synchrotron instance(s) in this instance.
98 Obviously you should configure your reverse-proxy to route the relevant
99 endpoints to the worker (``localhost:8083`` in the above example).
75100
76101 Finally, to actually run your worker-based synapse, you must pass synctl the -a
77102 commandline option to tell it to operate on all the worker configurations found
88113
89114 synctl -w $CONFIG/workers/synchrotron.yaml restart
90115
91 All of the above is highly experimental and subject to change as Synapse evolves,
92 but documenting it here to help folks needing highly scalable Synapses similar
93 to the one running matrix.org!
116
117 Available worker applications
118 -----------------------------
119
120 ``synapse.app.pusher``
121 ~~~~~~~~~~~~~~~~~~~~~~
122
123 Handles sending push notifications to sygnal and email. Doesn't handle any
124 REST endpoints itself, but you should set ``start_pushers: False`` in the
125 shared configuration file to stop the main synapse sending these notifications.
126
127 Note this worker cannot be load-balanced: only one instance should be active.
128
129 ``synapse.app.synchrotron``
130 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
131
132 The synchrotron handles ``sync`` requests from clients. In particular, it can
133 handle REST endpoints matching the following regular expressions::
134
135 ^/_matrix/client/(v2_alpha|r0)/sync$
136 ^/_matrix/client/(api/v1|v2_alpha|r0)/events$
137 ^/_matrix/client/(api/v1|r0)/initialSync$
138 ^/_matrix/client/(api/v1|r0)/rooms/[^/]+/initialSync$
139
140 The above endpoints should all be routed to the synchrotron worker by the
141 reverse-proxy configuration.
142
143 It is possible to run multiple instances of the synchrotron to scale
144 horizontally. In this case the reverse-proxy should be configured to
145 load-balance across the instances, though it will be more efficient if all
146 requests from a particular user are routed to a single instance. Extracting
147 a userid from the access token is currently left as an exercise for the reader.
148
149 ``synapse.app.appservice``
150 ~~~~~~~~~~~~~~~~~~~~~~~~~~
151
152 Handles sending output traffic to Application Services. Doesn't handle any
153 REST endpoints itself, but you should set ``notify_appservices: False`` in the
154 shared configuration file to stop the main synapse sending these notifications.
155
156 Note this worker cannot be load-balanced: only one instance should be active.
157
158 ``synapse.app.federation_reader``
159 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
160
161 Handles a subset of federation endpoints. In particular, it can handle REST
162 endpoints matching the following regular expressions::
163
164 ^/_matrix/federation/v1/event/
165 ^/_matrix/federation/v1/state/
166 ^/_matrix/federation/v1/state_ids/
167 ^/_matrix/federation/v1/backfill/
168 ^/_matrix/federation/v1/get_missing_events/
169 ^/_matrix/federation/v1/publicRooms
170
171 The above endpoints should all be routed to the federation_reader worker by the
172 reverse-proxy configuration.
173
174 ``synapse.app.federation_sender``
175 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
176
177 Handles sending federation traffic to other servers. Doesn't handle any
178 REST endpoints itself, but you should set ``send_federation: False`` in the
179 shared configuration file to stop the main synapse sending this traffic.
180
181 Note this worker cannot be load-balanced: only one instance should be active.
182
183 ``synapse.app.media_repository``
184 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
185
186 Handles the media repository. It can handle all endpoints starting with::
187
188 /_matrix/media/
189
190 You should also set ``enable_media_repo: False`` in the shared configuration
191 file to stop the main synapse running background jobs related to managing the
192 media repository.
193
194 Note this worker cannot be load-balanced: only one instance should be active.
195
196 ``synapse.app.client_reader``
197 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
198
199 Handles client API endpoints. It can handle REST endpoints matching the
200 following regular expressions::
201
202 ^/_matrix/client/(api/v1|r0|unstable)/publicRooms$
203
204 ``synapse.app.user_dir``
205 ~~~~~~~~~~~~~~~~~~~~~~~~
206
207 Handles searches in the user directory. It can handle REST endpoints matching
208 the following regular expressions::
209
210 ^/_matrix/client/(api/v1|r0|unstable)/user_directory/search$
211
212 ``synapse.app.frontend_proxy``
213 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
214
215 Proxies some frequently-requested client endpoints to add caching and remove
216 load from the main synapse. It can handle REST endpoints matching the following
217 regular expressions::
218
219 ^/_matrix/client/(api/v1|r0|unstable)/keys/upload
220
221 It will proxy any requests it cannot handle to the main synapse instance. It
222 must therefore be configured with the location of the main instance, via
223 the ``worker_main_http_uri`` setting in the frontend_proxy worker configuration
224 file. For example::
225
226 worker_main_http_uri: http://127.0.0.1:8008
227
228
229 ``synapse.app.event_creator``
230 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
231
232 Handles non-state event creation. It can handle REST endpoints matching::
233
234 ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send
235
236 It will create events locally and then send them on to the main synapse
237 instance to be persisted and handled.
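A hedged sketch of what this worker's configuration adds (the port numbers are illustrative, not prescribed): besides the usual worker options and a ``client`` listener covering the ``/send`` endpoint above, the worker start-up code introduced in this release asserts that ``worker_replication_http_port`` is set, so the worker file should point it at a main-process listener serving the new ``replication`` resource, e.g.::

    worker_app: synapse.app.event_creator

    worker_replication_host: 127.0.0.1
    worker_replication_port: 9092
    worker_replication_http_port: 9093

plus a ``client`` listener of its own, configured as in the synchrotron example above.
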
0 #!/usr/bin/env python
1 # -*- coding: utf-8 -*-
2 # Copyright 2017 New Vector Ltd
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 # http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15
16 """
17 Moves a list of remote media from one media store to another.
18
19 The input should be a list of media files to be moved, one per line. Each line
20 should be formatted::
21
22 <origin server>|<file id>
23
24 This can be extracted from postgres with::
25
26 psql --tuples-only -A -c "select media_origin, filesystem_id from
27 matrix.remote_media_cache where ..."
28
29 To use, pipe the above into::
30
31 PYTHON_PATH=. ./scripts/move_remote_media_to_new_store.py <source repo> <dest repo>
32 """
33
34 from __future__ import print_function
35
36 import argparse
37 import logging
38
39 import sys
40
41 import os
42
43 import shutil
44
45 from synapse.rest.media.v1.filepath import MediaFilePaths
46
47 logger = logging.getLogger()
48
49
50 def main(src_repo, dest_repo):
51 src_paths = MediaFilePaths(src_repo)
52 dest_paths = MediaFilePaths(dest_repo)
53 for line in sys.stdin:
54 line = line.strip()
55 parts = line.split('|')
56 if len(parts) != 2:
57 print("Unable to parse input line %s" % line, file=sys.stderr)
58 exit(1)
59
60 move_media(parts[0], parts[1], src_paths, dest_paths)
61
62
63 def move_media(origin_server, file_id, src_paths, dest_paths):
64 """Move the given file, and any thumbnails, to the dest repo
65
66 Args:
67 origin_server (str):
68 file_id (str):
69 src_paths (MediaFilePaths):
70 dest_paths (MediaFilePaths):
71 """
72 logger.info("%s/%s", origin_server, file_id)
73
74 # check that the original exists
75 original_file = src_paths.remote_media_filepath(origin_server, file_id)
76 if not os.path.exists(original_file):
77 logger.warn(
78 "Original for %s/%s (%s) does not exist",
79 origin_server, file_id, original_file,
80 )
81 else:
82 mkdir_and_move(
83 original_file,
84 dest_paths.remote_media_filepath(origin_server, file_id),
85 )
86
87 # now look for thumbnails
88 original_thumb_dir = src_paths.remote_media_thumbnail_dir(
89 origin_server, file_id,
90 )
91 if not os.path.exists(original_thumb_dir):
92 return
93
94 mkdir_and_move(
95 original_thumb_dir,
96 dest_paths.remote_media_thumbnail_dir(origin_server, file_id)
97 )
98
99
100 def mkdir_and_move(original_file, dest_file):
101 dirname = os.path.dirname(dest_file)
102 if not os.path.exists(dirname):
103 logger.debug("mkdir %s", dirname)
104 os.makedirs(dirname)
105 logger.debug("mv %s %s", original_file, dest_file)
106 shutil.move(original_file, dest_file)
107
108
109 if __name__ == "__main__":
110 parser = argparse.ArgumentParser(
111 description=__doc__,
112 formatter_class=argparse.RawDescriptionHelpFormatter,
113 )
114 parser.add_argument(
115 "-v", action='store_true', help='enable debug logging')
116 parser.add_argument(
117 "src_repo",
118 help="Path to source content repo",
119 )
120 parser.add_argument(
121 "dest_repo",
122 help="Path to source content repo",
123 )
124 args = parser.parse_args()
125
126 logging_config = {
127 "level": logging.DEBUG if args.v else logging.INFO,
128 "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
129 }
130 logging.basicConfig(**logging_config)
131
132 main(args.src_repo, args.dest_repo)
4141 "public_room_list_stream": ["visibility"],
4242 "device_lists_outbound_pokes": ["sent"],
4343 "users_who_share_rooms": ["share_private"],
44 "groups": ["is_public"],
45 "group_rooms": ["is_public"],
46 "group_users": ["is_public", "is_admin"],
47 "group_summary_rooms": ["is_public"],
48 "group_room_categories": ["is_public"],
49 "group_summary_users": ["is_public"],
50 "group_roles": ["is_public"],
51 "local_group_membership": ["is_publicised", "is_admin"],
4452 }
4553
4654
111119
112120 _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
113121 _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
122 _simple_update_txn = SQLBaseStore.__dict__["_simple_update_txn"]
114123
115124 def runInteraction(self, desc, func, *args, **kwargs):
116125 def r(conn):
317326 backward_chunk = min(row[0] for row in brows) - 1
318327
319328 rows = frows + brows
320 self._convert_rows(table, headers, rows)
329 rows = self._convert_rows(table, headers, rows)
321330
322331 def insert(txn):
323332 self.postgres_store.insert_many_txn(
553562 i for i, h in enumerate(headers) if h in bool_col_names
554563 ]
555564
565 class BadValueException(Exception):
566 pass
567
556568 def conv(j, col):
557569 if j in bool_cols:
558570 return bool(col)
571 elif isinstance(col, basestring) and "\0" in col:
572 logger.warn("DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], col)
573 raise BadValueException()
559574 return col
560575
576 outrows = []
561577 for i, row in enumerate(rows):
562 rows[i] = tuple(
563 conv(j, col)
564 for j, col in enumerate(row)
565 if j > 0
566 )
578 try:
579 outrows.append(tuple(
580 conv(j, col)
581 for j, col in enumerate(row)
582 if j > 0
583 ))
584 except BadValueException:
585 pass
586
587 return outrows
567588
568589 @defer.inlineCallbacks
569590 def _setup_sent_transactions(self):
591612 "select", r,
592613 )
593614
594 self._convert_rows("sent_transactions", headers, rows)
615 rows = self._convert_rows("sent_transactions", headers, rows)
595616
596617 inserted_rows = len(rows)
597618 if inserted_rows:
0 #!/usr/bin/env perl
1
2 use strict;
3 use warnings;
4
5 use JSON::XS;
6 use LWP::UserAgent;
7 use URI::Escape;
8
9 if (@ARGV < 4) {
10 die "usage: $0 <homeserver url> <access_token> <room_id|room_alias> <group_id>\n";
11 }
12
13 my ($hs, $access_token, $room_id, $group_id) = @ARGV;
14 my $ua = LWP::UserAgent->new();
15 $ua->timeout(10);
16
17 if ($room_id =~ /^#/) {
18 $room_id = uri_escape($room_id);
19 $room_id = decode_json($ua->get("${hs}/_matrix/client/r0/directory/room/${room_id}?access_token=${access_token}")->decoded_content)->{room_id};
20 }
21
22 my $room_users = [ keys %{decode_json($ua->get("${hs}/_matrix/client/r0/rooms/${room_id}/joined_members?access_token=${access_token}")->decoded_content)->{joined}} ];
23 my $group_users = [
24 (map { $_->{user_id} } @{decode_json($ua->get("${hs}/_matrix/client/unstable/groups/${group_id}/users?access_token=${access_token}" )->decoded_content)->{chunk}}),
25 (map { $_->{user_id} } @{decode_json($ua->get("${hs}/_matrix/client/unstable/groups/${group_id}/invited_users?access_token=${access_token}" )->decoded_content)->{chunk}}),
26 ];
27
28 die "refusing to sync from empty room" unless (@$room_users);
29 die "refusing to sync to empty group" unless (@$group_users);
30
31 my $diff = {};
32 foreach my $user (@$room_users) { $diff->{$user}++ }
33 foreach my $user (@$group_users) { $diff->{$user}-- }
34
35 foreach my $user (keys %$diff) {
36 if ($diff->{$user} == 1) {
37 warn "inviting $user";
38 print STDERR $ua->put("${hs}/_matrix/client/unstable/groups/${group_id}/admin/users/invite/${user}?access_token=${access_token}", Content=>'{}')->status_line."\n";
39 }
40 elsif ($diff->{$user} == -1) {
41 warn "removing $user";
42 print STDERR $ua->put("${hs}/_matrix/client/unstable/groups/${group_id}/admin/users/remove/${user}?access_token=${access_token}", Content=>'{}')->status_line."\n";
43 }
44 }
122122 except:
123123 return "https://%s:%d%s" % (destination, 8448, path)
124124
125 def get_json(origin_name, origin_key, destination, path):
126 request_json = {
127 "method": "GET",
125
126 def request_json(method, origin_name, origin_key, destination, path, content):
127 if method is None:
128 if content is None:
129 method = "GET"
130 else:
131 method = "POST"
132
133 json_to_sign = {
134 "method": method,
128135 "uri": path,
129136 "origin": origin_name,
130137 "destination": destination,
131138 }
132139
133 signed_json = sign_json(request_json, origin_key, origin_name)
140 if content is not None:
141 json_to_sign["content"] = json.loads(content)
142
143 signed_json = sign_json(json_to_sign, origin_key, origin_name)
134144
135145 authorization_headers = []
136146
144154 dest = lookup(destination, path)
145155 print ("Requesting %s" % dest, file=sys.stderr)
146156
147 result = requests.get(
148 dest,
157 result = requests.request(
158 method=method,
159 url=dest,
149160 headers={"Authorization": authorization_headers[0]},
150161 verify=False,
162 data=content,
151163 )
152164 sys.stderr.write("Status Code: %d\n" % (result.status_code,))
153165 return result.json()
186198 )
187199
188200 parser.add_argument(
201 "-X", "--method",
202 help="HTTP method to use for the request. Defaults to GET if --data is"
203 "unspecified, POST if it is."
204 )
205
206 parser.add_argument(
207 "--body",
208 help="Data to send as the body of the HTTP request"
209 )
210
211 parser.add_argument(
189212 "path",
190213 help="request path. We will add '/_matrix/federation/v1/' to this."
191214 )
198221 with open(args.signing_key_path) as f:
199222 key = read_signing_keys(f)[0]
200223
201 result = get_json(
202 args.server_name, key, args.destination, "/_matrix/federation/v1/" + args.path
224 result = request_json(
225 args.method,
226 args.server_name, key, args.destination,
227 "/_matrix/federation/v1/" + args.path,
228 content=args.body,
203229 )
204230
205231 json.dump(result, sys.stdout)
1515 """ This is a reference implementation of a Matrix home server.
1616 """
1717
18 __version__ = "0.24.0"
18 __version__ = "0.27.2"
269269 rights (str): The operation being performed; the access token must
270270 allow this.
271271 Returns:
272 dict : dict that includes the user and the ID of their access token.
272 Deferred[dict]: dict that includes:
273 `user` (UserID)
274 `is_guest` (bool)
275 `token_id` (int|None): access token id. May be None if guest
276 `device_id` (str|None): device corresponding to access token
273277 Raises:
274278 AuthError if no user by that token exists or the token is invalid.
275279 """
4545 THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
4646 THREEPID_IN_USE = "M_THREEPID_IN_USE"
4747 THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
48 THREEPID_DENIED = "M_THREEPID_DENIED"
4849 INVALID_USERNAME = "M_INVALID_USERNAME"
4950 SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
5051
137138 class RegistrationError(SynapseError):
138139 """An error raised when a registration event fails."""
139140 pass
141
142
143 class FederationDeniedError(SynapseError):
144 """An error raised when the server tries to federate with a server which
145 is not on its federation whitelist.
146
147 Attributes:
148 destination (str): The destination which has been denied
149 """
150
151 def __init__(self, destination):
152 """Raised by federation client or server to indicate that we are
153 deliberately not attempting to contact a given server because it is
154 not on our federation whitelist.
155
156 Args:
157 destination (str): the domain in question
158 """
159
160 self.destination = destination
161
162 super(FederationDeniedError, self).__init__(
163 code=403,
164 msg="Federation denied with %s." % (self.destination,),
165 errcode=Codes.FORBIDDEN,
166 )
167
168
169 class InteractiveAuthIncompleteError(Exception):
170 """An error raised when UI auth is not yet complete
171
172 (This indicates we should return a 401 with 'result' as the body)
173
174 Attributes:
175 result (dict): the server response to the request, which should be
176 passed back to the client
177 """
178 def __init__(self, result):
179 super(InteractiveAuthIncompleteError, self).__init__(
180 "Interactive auth not yet complete",
181 )
182 self.result = result
140183
141184
142185 class UnrecognizedRequestError(SynapseError):
1616 from synapse.types import UserID, RoomID
1717 from twisted.internet import defer
1818
19 import ujson as json
19 import simplejson as json
2020 import jsonschema
2121 from jsonschema import FormatChecker
2222
1818
1919 try:
2020 import affinity
21 except:
21 except Exception:
2222 affinity = None
2323
2424 from daemonize import Daemonize
2525 from synapse.util import PreserveLoggingContext
2626 from synapse.util.rlimit import change_resource_limit
27 from twisted.internet import reactor
27 from twisted.internet import error, reactor
28
29 logger = logging.getLogger(__name__)
2830
2931
3032 def start_worker_reactor(appname, config):
119121 sys.stderr.write(" %s\n" % (line.rstrip(),))
120122 sys.stderr.write("*" * line_length + '\n')
121123 sys.exit(1)
124
125
126 def listen_tcp(bind_addresses, port, factory, backlog=50):
127 """
128 Create a TCP socket for a port and several addresses
129 """
130 for address in bind_addresses:
131 try:
132 reactor.listenTCP(
133 port,
134 factory,
135 backlog,
136 address
137 )
138 except error.CannotListenError as e:
139 check_bind_error(e, address, bind_addresses)
140
141
142 def listen_ssl(bind_addresses, port, factory, context_factory, backlog=50):
143 """
144 Create an SSL socket for a port and several addresses
145 """
146 for address in bind_addresses:
147 try:
148 reactor.listenSSL(
149 port,
150 factory,
151 context_factory,
152 backlog,
153 address
154 )
155 except error.CannotListenError as e:
156 check_bind_error(e, address, bind_addresses)
157
158
159 def check_bind_error(e, address, bind_addresses):
160 """
161 This method checks an exception which occurred while binding on 0.0.0.0.
162 If :: is also specified in the bind addresses, only a warning is shown.
163 The exception is still raised otherwise.
164
165 Binding on both 0.0.0.0 and :: causes an exception on Linux and macOS
166 because :: binds on both IPv4 and IPv6 (as per RFC 3493).
167 When binding on 0.0.0.0 after :: this can safely be ignored.
168
169 Args:
170 e (Exception): Exception that was caught.
171 address (str): Address on which binding was attempted.
172 bind_addresses (list): Addresses on which the service listens.
173 """
174 if address == '0.0.0.0' and '::' in bind_addresses:
175 logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]')
176 else:
177 raise e
4848
4949
5050 class AppserviceServer(HomeServer):
51 def get_db_conn(self, run_new_connection=True):
52 # Any param beginning with cp_ is a parameter for adbapi, and should
53 # not be passed to the database engine.
54 db_params = {
55 k: v for k, v in self.db_config.get("args", {}).items()
56 if not k.startswith("cp_")
57 }
58 db_conn = self.database_engine.module.connect(**db_params)
59
60 if run_new_connection:
61 self.database_engine.on_new_connection(db_conn)
62 return db_conn
63
6451 def setup(self):
6552 logger.info("Setting up.")
6653 self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
7865
7966 root_resource = create_resource_tree(resources, Resource())
8067
81 for address in bind_addresses:
82 reactor.listenTCP(
83 port,
84 SynapseSite(
85 "synapse.access.http.%s" % (site_tag,),
86 site_tag,
87 listener_config,
88 root_resource,
89 ),
90 interface=address
68 _base.listen_tcp(
69 bind_addresses,
70 port,
71 SynapseSite(
72 "synapse.access.http.%s" % (site_tag,),
73 site_tag,
74 listener_config,
75 root_resource,
9176 )
77 )
9278
9379 logger.info("Synapse appservice now listening on port %d", port)
9480
9783 if listener["type"] == "http":
9884 self._listen_http(listener)
9985 elif listener["type"] == "manhole":
100 bind_addresses = listener["bind_addresses"]
101
102 for address in bind_addresses:
103 reactor.listenTCP(
104 listener["port"],
105 manhole(
106 username="matrix",
107 password="rabbithole",
108 globals={"hs": self},
109 ),
110 interface=address
86 _base.listen_tcp(
87 listener["bind_addresses"],
88 listener["port"],
89 manhole(
90 username="matrix",
91 password="rabbithole",
92 globals={"hs": self},
11193 )
94 )
11295 else:
11396 logger.warn("Unrecognized listener type: %s", listener["type"])
11497
6363
6464
6565 class ClientReaderServer(HomeServer):
66 def get_db_conn(self, run_new_connection=True):
67 # Any param beginning with cp_ is a parameter for adbapi, and should
68 # not be passed to the database engine.
69 db_params = {
70 k: v for k, v in self.db_config.get("args", {}).items()
71 if not k.startswith("cp_")
72 }
73 db_conn = self.database_engine.module.connect(**db_params)
74
75 if run_new_connection:
76 self.database_engine.on_new_connection(db_conn)
77 return db_conn
78
7966 def setup(self):
8067 logger.info("Setting up.")
8168 self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
10289
10390 root_resource = create_resource_tree(resources, Resource())
10491
105 for address in bind_addresses:
106 reactor.listenTCP(
107 port,
108 SynapseSite(
109 "synapse.access.http.%s" % (site_tag,),
110 site_tag,
111 listener_config,
112 root_resource,
113 ),
114 interface=address
92 _base.listen_tcp(
93 bind_addresses,
94 port,
95 SynapseSite(
96 "synapse.access.http.%s" % (site_tag,),
97 site_tag,
98 listener_config,
99 root_resource,
115100 )
101 )
116102
117103 logger.info("Synapse client reader now listening on port %d", port)
118104
121107 if listener["type"] == "http":
122108 self._listen_http(listener)
123109 elif listener["type"] == "manhole":
124 bind_addresses = listener["bind_addresses"]
110 _base.listen_tcp(
111 listener["bind_addresses"],
112 listener["port"],
113 manhole(
114 username="matrix",
115 password="rabbithole",
116 globals={"hs": self},
117 )
118 )
125119
126 for address in bind_addresses:
127 reactor.listenTCP(
128 listener["port"],
129 manhole(
130 username="matrix",
131 password="rabbithole",
132 globals={"hs": self},
133 ),
134 interface=address
135 )
136120 else:
137121 logger.warn("Unrecognized listener type: %s", listener["type"])
138122
171155 )
172156
173157 ss.setup()
174 ss.get_handlers()
175158 ss.start_listening(config.worker_listeners)
176159
177160 def start():
0 #!/usr/bin/env python
1 # -*- coding: utf-8 -*-
2 # Copyright 2018 New Vector Ltd
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 # http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15 import logging
16 import sys
17
18 import synapse
19 from synapse import events
20 from synapse.app import _base
21 from synapse.config._base import ConfigError
22 from synapse.config.homeserver import HomeServerConfig
23 from synapse.config.logger import setup_logging
24 from synapse.crypto import context_factory
25 from synapse.http.server import JsonResource
26 from synapse.http.site import SynapseSite
27 from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
28 from synapse.replication.slave.storage._base import BaseSlavedStore
29 from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
30 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
31 from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
32 from synapse.replication.slave.storage.devices import SlavedDeviceStore
33 from synapse.replication.slave.storage.directory import DirectoryStore
34 from synapse.replication.slave.storage.events import SlavedEventStore
35 from synapse.replication.slave.storage.profile import SlavedProfileStore
36 from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
37 from synapse.replication.slave.storage.pushers import SlavedPusherStore
38 from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
39 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
40 from synapse.replication.slave.storage.room import RoomStore
41 from synapse.replication.slave.storage.transactions import TransactionStore
42 from synapse.replication.tcp.client import ReplicationClientHandler
43 from synapse.rest.client.v1.room import (
44 RoomSendEventRestServlet, RoomMembershipRestServlet, RoomStateEventRestServlet,
45 JoinRoomAliasServlet,
46 )
47 from synapse.server import HomeServer
48 from synapse.storage.engines import create_engine
49 from synapse.util.httpresourcetree import create_resource_tree
50 from synapse.util.logcontext import LoggingContext
51 from synapse.util.manhole import manhole
52 from synapse.util.versionstring import get_version_string
53 from twisted.internet import reactor
54 from twisted.web.resource import Resource
55
56 logger = logging.getLogger("synapse.app.event_creator")
57
58
59 class EventCreatorSlavedStore(
60 DirectoryStore,
61 TransactionStore,
62 SlavedProfileStore,
63 SlavedAccountDataStore,
64 SlavedPusherStore,
65 SlavedReceiptsStore,
66 SlavedPushRuleStore,
67 SlavedDeviceStore,
68 SlavedClientIpStore,
69 SlavedApplicationServiceStore,
70 SlavedEventStore,
71 SlavedRegistrationStore,
72 RoomStore,
73 BaseSlavedStore,
74 ):
75 pass
76
77
78 class EventCreatorServer(HomeServer):
79 def setup(self):
80 logger.info("Setting up.")
81 self.datastore = EventCreatorSlavedStore(self.get_db_conn(), self)
82 logger.info("Finished setting up.")
83
84 def _listen_http(self, listener_config):
85 port = listener_config["port"]
86 bind_addresses = listener_config["bind_addresses"]
87 site_tag = listener_config.get("tag", port)
88 resources = {}
89 for res in listener_config["resources"]:
90 for name in res["names"]:
91 if name == "metrics":
92 resources[METRICS_PREFIX] = MetricsResource(self)
93 elif name == "client":
94 resource = JsonResource(self, canonical_json=False)
95 RoomSendEventRestServlet(self).register(resource)
96 RoomMembershipRestServlet(self).register(resource)
97 RoomStateEventRestServlet(self).register(resource)
98 JoinRoomAliasServlet(self).register(resource)
99 resources.update({
100 "/_matrix/client/r0": resource,
101 "/_matrix/client/unstable": resource,
102 "/_matrix/client/v2_alpha": resource,
103 "/_matrix/client/api/v1": resource,
104 })
105
106 root_resource = create_resource_tree(resources, Resource())
107
108 _base.listen_tcp(
109 bind_addresses,
110 port,
111 SynapseSite(
112 "synapse.access.http.%s" % (site_tag,),
113 site_tag,
114 listener_config,
115 root_resource,
116 )
117 )
118
119 logger.info("Synapse event creator now listening on port %d", port)
120
121 def start_listening(self, listeners):
122 for listener in listeners:
123 if listener["type"] == "http":
124 self._listen_http(listener)
125 elif listener["type"] == "manhole":
126 _base.listen_tcp(
127 listener["bind_addresses"],
128 listener["port"],
129 manhole(
130 username="matrix",
131 password="rabbithole",
132 globals={"hs": self},
133 )
134 )
135 else:
136 logger.warn("Unrecognized listener type: %s", listener["type"])
137
138 self.get_tcp_replication().start_replication(self)
139
140 def build_tcp_replication(self):
141 return ReplicationClientHandler(self.get_datastore())
142
143
144 def start(config_options):
145 try:
146 config = HomeServerConfig.load_config(
147 "Synapse event creator", config_options
148 )
149 except ConfigError as e:
150 sys.stderr.write("\n" + e.message + "\n")
151 sys.exit(1)
152
153 assert config.worker_app == "synapse.app.event_creator"
154
155 assert config.worker_replication_http_port is not None
156
157 setup_logging(config, use_worker_options=True)
158
159 events.USE_FROZEN_DICTS = config.use_frozen_dicts
160
161 database_engine = create_engine(config.database_config)
162
163 tls_server_context_factory = context_factory.ServerContextFactory(config)
164
165 ss = EventCreatorServer(
166 config.server_name,
167 db_config=config.database_config,
168 tls_server_context_factory=tls_server_context_factory,
169 config=config,
170 version_string="Synapse/" + get_version_string(synapse),
171 database_engine=database_engine,
172 )
173
174 ss.setup()
175 ss.start_listening(config.worker_listeners)
176
177 def start():
178 ss.get_state_handler().start_caching()
179 ss.get_datastore().start_profiling()
180
181 reactor.callWhenRunning(start)
182
183 _base.start_worker_reactor("synapse-event-creator", config)
184
185
186 if __name__ == '__main__':
187 with LoggingContext("main"):
188 start(sys.argv[1:])
5757
5858
5959 class FederationReaderServer(HomeServer):
60 def get_db_conn(self, run_new_connection=True):
61 # Any param beginning with cp_ is a parameter for adbapi, and should
62 # not be passed to the database engine.
63 db_params = {
64 k: v for k, v in self.db_config.get("args", {}).items()
65 if not k.startswith("cp_")
66 }
67 db_conn = self.database_engine.module.connect(**db_params)
68
69 if run_new_connection:
70 self.database_engine.on_new_connection(db_conn)
71 return db_conn
72
7360 def setup(self):
7461 logger.info("Setting up.")
7562 self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
9178
9279 root_resource = create_resource_tree(resources, Resource())
9380
94 for address in bind_addresses:
95 reactor.listenTCP(
96 port,
97 SynapseSite(
98 "synapse.access.http.%s" % (site_tag,),
99 site_tag,
100 listener_config,
101 root_resource,
102 ),
103 interface=address
81 _base.listen_tcp(
82 bind_addresses,
83 port,
84 SynapseSite(
85 "synapse.access.http.%s" % (site_tag,),
86 site_tag,
87 listener_config,
88 root_resource,
10489 )
90 )
10591
10692 logger.info("Synapse federation reader now listening on port %d", port)
10793
11096 if listener["type"] == "http":
11197 self._listen_http(listener)
11298 elif listener["type"] == "manhole":
113 bind_addresses = listener["bind_addresses"]
114
115 for address in bind_addresses:
116 reactor.listenTCP(
117 listener["port"],
118 manhole(
119 username="matrix",
120 password="rabbithole",
121 globals={"hs": self},
122 ),
123 interface=address
99 _base.listen_tcp(
100 listener["bind_addresses"],
101 listener["port"],
102 manhole(
103 username="matrix",
104 password="rabbithole",
105 globals={"hs": self},
124106 )
107 )
125108 else:
126109 logger.warn("Unrecognized listener type: %s", listener["type"])
127110
160143 )
161144
162145 ss.setup()
163 ss.get_handlers()
164146 ss.start_listening(config.worker_listeners)
165147
166148 def start():
7575
7676
7777 class FederationSenderServer(HomeServer):
78 def get_db_conn(self, run_new_connection=True):
79 # Any param beginning with cp_ is a parameter for adbapi, and should
80 # not be passed to the database engine.
81 db_params = {
82 k: v for k, v in self.db_config.get("args", {}).items()
83 if not k.startswith("cp_")
84 }
85 db_conn = self.database_engine.module.connect(**db_params)
86
87 if run_new_connection:
88 self.database_engine.on_new_connection(db_conn)
89 return db_conn
90
9178 def setup(self):
9279 logger.info("Setting up.")
9380 self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
10592
10693 root_resource = create_resource_tree(resources, Resource())
10794
108 for address in bind_addresses:
109 reactor.listenTCP(
110 port,
111 SynapseSite(
112 "synapse.access.http.%s" % (site_tag,),
113 site_tag,
114 listener_config,
115 root_resource,
116 ),
117 interface=address
95 _base.listen_tcp(
96 bind_addresses,
97 port,
98 SynapseSite(
99 "synapse.access.http.%s" % (site_tag,),
100 site_tag,
101 listener_config,
102 root_resource,
118103 )
104 )
119105
120106 logger.info("Synapse federation_sender now listening on port %d", port)
121107
124110 if listener["type"] == "http":
125111 self._listen_http(listener)
126112 elif listener["type"] == "manhole":
127 bind_addresses = listener["bind_addresses"]
128
129 for address in bind_addresses:
130 reactor.listenTCP(
131 listener["port"],
132 manhole(
133 username="matrix",
134 password="rabbithole",
135 globals={"hs": self},
136 ),
137 interface=address
113 _base.listen_tcp(
114 listener["bind_addresses"],
115 listener["port"],
116 manhole(
117 username="matrix",
118 password="rabbithole",
119 globals={"hs": self},
138120 )
121 )
139122 else:
140123 logger.warn("Unrecognized listener type: %s", listener["type"])
141124
4949
5050
5151 class KeyUploadServlet(RestServlet):
52 PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
53 releases=())
52 PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
5453
5554 def __init__(self, hs):
5655 """
8887
8988 if body:
9089 # They're actually trying to upload something, proxy to main synapse.
90 # Pass through the auth headers, if any, in case the access token
91 # is there.
92 auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
93 headers = {
94 "Authorization": auth_headers,
95 }
9196 result = yield self.http_client.post_json_get_json(
9297 self.main_uri + request.uri,
9398 body,
99 headers=headers,
94100 )
95101
96102 defer.returnValue((200, result))
111117
112118
113119 class FrontendProxyServer(HomeServer):
114 def get_db_conn(self, run_new_connection=True):
115 # Any param beginning with cp_ is a parameter for adbapi, and should
116 # not be passed to the database engine.
117 db_params = {
118 k: v for k, v in self.db_config.get("args", {}).items()
119 if not k.startswith("cp_")
120 }
121 db_conn = self.database_engine.module.connect(**db_params)
122
123 if run_new_connection:
124 self.database_engine.on_new_connection(db_conn)
125 return db_conn
126
127120 def setup(self):
128121 logger.info("Setting up.")
129122 self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self)
150143
151144 root_resource = create_resource_tree(resources, Resource())
152145
153 for address in bind_addresses:
154 reactor.listenTCP(
155 port,
156 SynapseSite(
157 "synapse.access.http.%s" % (site_tag,),
158 site_tag,
159 listener_config,
160 root_resource,
161 ),
162 interface=address
146 _base.listen_tcp(
147 bind_addresses,
148 port,
149 SynapseSite(
150 "synapse.access.http.%s" % (site_tag,),
151 site_tag,
152 listener_config,
153 root_resource,
163154 )
155 )
164156
165157 logger.info("Synapse client reader now listening on port %d", port)
166158
169161 if listener["type"] == "http":
170162 self._listen_http(listener)
171163 elif listener["type"] == "manhole":
172 bind_addresses = listener["bind_addresses"]
173
174 for address in bind_addresses:
175 reactor.listenTCP(
176 listener["port"],
177 manhole(
178 username="matrix",
179 password="rabbithole",
180 globals={"hs": self},
181 ),
182 interface=address
164 _base.listen_tcp(
165 listener["bind_addresses"],
166 listener["port"],
167 manhole(
168 username="matrix",
169 password="rabbithole",
170 globals={"hs": self},
183171 )
172 )
184173 else:
185174 logger.warn("Unrecognized listener type: %s", listener["type"])
186175
221210 )
222211
223212 ss.setup()
224 ss.get_handlers()
225213 ss.start_listening(config.worker_listeners)
226214
227215 def start():
2424 LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, SERVER_KEY_PREFIX, SERVER_KEY_V2_PREFIX, \
2525 STATIC_PREFIX, WEB_CLIENT_PREFIX
2626 from synapse.app import _base
27 from synapse.app._base import quit_with_error
27 from synapse.app._base import quit_with_error, listen_ssl, listen_tcp
2828 from synapse.config._base import ConfigError
2929 from synapse.config.homeserver import HomeServerConfig
3030 from synapse.crypto import context_factory
3131 from synapse.federation.transport.server import TransportLayerServer
32 from synapse.module_api import ModuleApi
33 from synapse.http.additional_resource import AdditionalResource
3234 from synapse.http.server import RootRedirect
3335 from synapse.http.site import SynapseSite
3436 from synapse.metrics import register_memory_metrics
3537 from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
3638 from synapse.python_dependencies import CONDITIONAL_REQUIREMENTS, \
3739 check_requirements
40 from synapse.replication.http import ReplicationRestResource, REPLICATION_PREFIX
3841 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
3942 from synapse.rest import ClientRestResource
4043 from synapse.rest.key.v1.server_key_resource import LocalKey
4144 from synapse.rest.key.v2 import KeyApiV2Resource
4245 from synapse.rest.media.v0.content_repository import ContentRepoResource
43 from synapse.rest.media.v1.media_repository import MediaRepositoryResource
4446 from synapse.server import HomeServer
4547 from synapse.storage import are_all_users_on_domain
4648 from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
4850 from synapse.util.httpresourcetree import create_resource_tree
4951 from synapse.util.logcontext import LoggingContext
5052 from synapse.util.manhole import manhole
53 from synapse.util.module_loader import load_module
5154 from synapse.util.rlimit import change_resource_limit
5255 from synapse.util.versionstring import get_version_string
5356 from twisted.application import service
106109 resources = {}
107110 for res in listener_config["resources"]:
108111 for name in res["names"]:
109 if name == "client":
110 client_resource = ClientRestResource(self)
111 if res["compress"]:
112 client_resource = gz_wrap(client_resource)
113
114 resources.update({
115 "/_matrix/client/api/v1": client_resource,
116 "/_matrix/client/r0": client_resource,
117 "/_matrix/client/unstable": client_resource,
118 "/_matrix/client/v2_alpha": client_resource,
119 "/_matrix/client/versions": client_resource,
120 })
121
122 if name == "federation":
123 resources.update({
124 FEDERATION_PREFIX: TransportLayerServer(self),
125 })
126
127 if name in ["static", "client"]:
128 resources.update({
129 STATIC_PREFIX: File(
130 os.path.join(os.path.dirname(synapse.__file__), "static")
131 ),
132 })
133
134 if name in ["media", "federation", "client"]:
135 media_repo = MediaRepositoryResource(self)
136 resources.update({
137 MEDIA_PREFIX: media_repo,
138 LEGACY_MEDIA_PREFIX: media_repo,
139 CONTENT_REPO_PREFIX: ContentRepoResource(
140 self, self.config.uploads_path
141 ),
142 })
143
144 if name in ["keys", "federation"]:
145 resources.update({
146 SERVER_KEY_PREFIX: LocalKey(self),
147 SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
148 })
149
150 if name == "webclient":
151 resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
152
153 if name == "metrics" and self.get_config().enable_metrics:
154 resources[METRICS_PREFIX] = MetricsResource(self)
112 resources.update(self._configure_named_resource(
113 name, res.get("compress", False),
114 ))
115
116 additional_resources = listener_config.get("additional_resources", {})
117 logger.debug("Configuring additional resources: %r",
118 additional_resources)
119 module_api = ModuleApi(self, self.get_auth_handler())
120 for path, resmodule in additional_resources.items():
121 handler_cls, config = load_module(resmodule)
122 handler = handler_cls(config, module_api)
123 resources[path] = AdditionalResource(self, handler.handle_request)
155124
156125 if WEB_CLIENT_PREFIX in resources:
157126 root_resource = RootRedirect(WEB_CLIENT_PREFIX)
161130 root_resource = create_resource_tree(resources, root_resource)
162131
163132 if tls:
164 for address in bind_addresses:
165 reactor.listenSSL(
166 port,
167 SynapseSite(
168 "synapse.access.https.%s" % (site_tag,),
169 site_tag,
170 listener_config,
171 root_resource,
133 listen_ssl(
134 bind_addresses,
135 port,
136 SynapseSite(
137 "synapse.access.https.%s" % (site_tag,),
138 site_tag,
139 listener_config,
140 root_resource,
141 ),
142 self.tls_server_context_factory,
143 )
144
145 else:
146 listen_tcp(
147 bind_addresses,
148 port,
149 SynapseSite(
150 "synapse.access.http.%s" % (site_tag,),
151 site_tag,
152 listener_config,
153 root_resource,
154 )
155 )
156 logger.info("Synapse now listening on port %d", port)
157
158 def _configure_named_resource(self, name, compress=False):
159 """Build a resource map for a named resource
160
161 Args:
162 name (str): named resource: one of "client", "federation", etc
163 compress (bool): whether to enable gzip compression for this
164 resource
165
166 Returns:
167 dict[str, Resource]: map from path to HTTP resource
168 """
169 resources = {}
170 if name == "client":
171 client_resource = ClientRestResource(self)
172 if compress:
173 client_resource = gz_wrap(client_resource)
174
175 resources.update({
176 "/_matrix/client/api/v1": client_resource,
177 "/_matrix/client/r0": client_resource,
178 "/_matrix/client/unstable": client_resource,
179 "/_matrix/client/v2_alpha": client_resource,
180 "/_matrix/client/versions": client_resource,
181 })
182
183 if name == "federation":
184 resources.update({
185 FEDERATION_PREFIX: TransportLayerServer(self),
186 })
187
188 if name in ["static", "client"]:
189 resources.update({
190 STATIC_PREFIX: File(
191 os.path.join(os.path.dirname(synapse.__file__), "static")
192 ),
193 })
194
195 if name in ["media", "federation", "client"]:
196 if self.get_config().enable_media_repo:
197 media_repo = self.get_media_repository_resource()
198 resources.update({
199 MEDIA_PREFIX: media_repo,
200 LEGACY_MEDIA_PREFIX: media_repo,
201 CONTENT_REPO_PREFIX: ContentRepoResource(
202 self, self.config.uploads_path
172203 ),
173 self.tls_server_context_factory,
174 interface=address
204 })
205 elif name == "media":
206 raise ConfigError(
207 "'media' resource conflicts with enable_media_repo=False",
175208 )
176 else:
177 for address in bind_addresses:
178 reactor.listenTCP(
179 port,
180 SynapseSite(
181 "synapse.access.http.%s" % (site_tag,),
182 site_tag,
183 listener_config,
184 root_resource,
185 ),
186 interface=address
187 )
188 logger.info("Synapse now listening on port %d", port)
209
210 if name in ["keys", "federation"]:
211 resources.update({
212 SERVER_KEY_PREFIX: LocalKey(self),
213 SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
214 })
215
216 if name == "webclient":
217 resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
218
219 if name == "metrics" and self.get_config().enable_metrics:
220 resources[METRICS_PREFIX] = MetricsResource(self)
221
222 if name == "replication":
223 resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
224
225 return resources
189226
190227 def start_listening(self):
191228 config = self.get_config()
194231 if listener["type"] == "http":
195232 self._listener_http(config, listener)
196233 elif listener["type"] == "manhole":
197 bind_addresses = listener["bind_addresses"]
198
199 for address in bind_addresses:
200 reactor.listenTCP(
201 listener["port"],
202 manhole(
203 username="matrix",
204 password="rabbithole",
205 globals={"hs": self},
206 ),
207 interface=address
234 listen_tcp(
235 listener["bind_addresses"],
236 listener["port"],
237 manhole(
238 username="matrix",
239 password="rabbithole",
240 globals={"hs": self},
208241 )
242 )
209243 elif listener["type"] == "replication":
210244 bind_addresses = listener["bind_addresses"]
211245 for address in bind_addresses:
235269 except IncorrectDatabaseSetup as e:
236270 quit_with_error(e.message)
237271
238 def get_db_conn(self, run_new_connection=True):
239 # Any param beginning with cp_ is a parameter for adbapi, and should
240 # not be passed to the database engine.
241 db_params = {
242 k: v for k, v in self.db_config.get("args", {}).items()
243 if not k.startswith("cp_")
244 }
245 db_conn = self.database_engine.module.connect(**db_params)
246
247 if run_new_connection:
248 self.database_engine.on_new_connection(db_conn)
249 return db_conn
250
251272
252273 def setup(config_options):
253274 """
326347 hs.get_state_handler().start_caching()
327348 hs.get_datastore().start_profiling()
328349 hs.get_datastore().start_doing_background_updates()
329 hs.get_replication_layer().start_get_pdu_cache()
350 hs.get_federation_client().start_get_pdu_cache()
330351
331352 register_memory_metrics(hs)
332353
3434 from synapse.replication.slave.storage.transactions import TransactionStore
3535 from synapse.replication.tcp.client import ReplicationClientHandler
3636 from synapse.rest.media.v0.content_repository import ContentRepoResource
37 from synapse.rest.media.v1.media_repository import MediaRepositoryResource
3837 from synapse.server import HomeServer
3938 from synapse.storage.engines import create_engine
4039 from synapse.storage.media_repository import MediaRepositoryStore
6059
6160
6261 class MediaRepositoryServer(HomeServer):
63 def get_db_conn(self, run_new_connection=True):
64 # Any param beginning with cp_ is a parameter for adbapi, and should
65 # not be passed to the database engine.
66 db_params = {
67 k: v for k, v in self.db_config.get("args", {}).items()
68 if not k.startswith("cp_")
69 }
70 db_conn = self.database_engine.module.connect(**db_params)
71
72 if run_new_connection:
73 self.database_engine.on_new_connection(db_conn)
74 return db_conn
75
7662 def setup(self):
7763 logger.info("Setting up.")
7864 self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
8874 if name == "metrics":
8975 resources[METRICS_PREFIX] = MetricsResource(self)
9076 elif name == "media":
91 media_repo = MediaRepositoryResource(self)
77 media_repo = self.get_media_repository_resource()
9278 resources.update({
9379 MEDIA_PREFIX: media_repo,
9480 LEGACY_MEDIA_PREFIX: media_repo,
9985
10086 root_resource = create_resource_tree(resources, Resource())
10187
102 for address in bind_addresses:
103 reactor.listenTCP(
104 port,
105 SynapseSite(
106 "synapse.access.http.%s" % (site_tag,),
107 site_tag,
108 listener_config,
109 root_resource,
110 ),
111 interface=address
88 _base.listen_tcp(
89 bind_addresses,
90 port,
91 SynapseSite(
92 "synapse.access.http.%s" % (site_tag,),
93 site_tag,
94 listener_config,
95 root_resource,
11296 )
97 )
11398
11499 logger.info("Synapse media repository now listening on port %d", port)
115100
118103 if listener["type"] == "http":
119104 self._listen_http(listener)
120105 elif listener["type"] == "manhole":
121 bind_addresses = listener["bind_addresses"]
122
123 for address in bind_addresses:
124 reactor.listenTCP(
125 listener["port"],
126 manhole(
127 username="matrix",
128 password="rabbithole",
129 globals={"hs": self},
130 ),
131 interface=address
106 _base.listen_tcp(
107 listener["bind_addresses"],
108 listener["port"],
109 manhole(
110 username="matrix",
111 password="rabbithole",
112 globals={"hs": self},
132113 )
114 )
133115 else:
134116 logger.warn("Unrecognized listener type: %s", listener["type"])
135117
150132
151133 assert config.worker_app == "synapse.app.media_repository"
152134
135 if config.enable_media_repo:
136 _base.quit_with_error(
137 "enable_media_repo must be disabled in the main synapse process\n"
138 "before the media repo can be run in a separate worker.\n"
139 "Please add ``enable_media_repo: false`` to the main config\n"
140 )
141
153142 setup_logging(config, use_worker_options=True)
154143
155144 events.USE_FROZEN_DICTS = config.use_frozen_dicts
168157 )
169158
170159 ss.setup()
171 ss.get_handlers()
172160 ss.start_listening(config.worker_listeners)
173161
174162 def start():
3131 from synapse.server import HomeServer
3232 from synapse.storage import DataStore
3333 from synapse.storage.engines import create_engine
34 from synapse.storage.roommember import RoomMemberStore
3534 from synapse.util.httpresourcetree import create_resource_tree
3635 from synapse.util.logcontext import LoggingContext, preserve_fn
3736 from synapse.util.manhole import manhole
7473 DataStore.get_profile_displayname.__func__
7574 )
7675
77 who_forgot_in_room = (
78 RoomMemberStore.__dict__["who_forgot_in_room"]
79 )
80
8176
8277 class PusherServer(HomeServer):
83 def get_db_conn(self, run_new_connection=True):
84 # Any param beginning with cp_ is a parameter for adbapi, and should
85 # not be passed to the database engine.
86 db_params = {
87 k: v for k, v in self.db_config.get("args", {}).items()
88 if not k.startswith("cp_")
89 }
90 db_conn = self.database_engine.module.connect(**db_params)
91
92 if run_new_connection:
93 self.database_engine.on_new_connection(db_conn)
94 return db_conn
95
9678 def setup(self):
9779 logger.info("Setting up.")
9880 self.datastore = PusherSlaveStore(self.get_db_conn(), self)
11395
11496 root_resource = create_resource_tree(resources, Resource())
11597
116 for address in bind_addresses:
117 reactor.listenTCP(
118 port,
119 SynapseSite(
120 "synapse.access.http.%s" % (site_tag,),
121 site_tag,
122 listener_config,
123 root_resource,
124 ),
125 interface=address
98 _base.listen_tcp(
99 bind_addresses,
100 port,
101 SynapseSite(
102 "synapse.access.http.%s" % (site_tag,),
103 site_tag,
104 listener_config,
105 root_resource,
126106 )
107 )
127108
128109 logger.info("Synapse pusher now listening on port %d", port)
129110
132113 if listener["type"] == "http":
133114 self._listen_http(listener)
134115 elif listener["type"] == "manhole":
135 bind_addresses = listener["bind_addresses"]
136
137 for address in bind_addresses:
138 reactor.listenTCP(
139 listener["port"],
140 manhole(
141 username="matrix",
142 password="rabbithole",
143 globals={"hs": self},
144 ),
145 interface=address
116 _base.listen_tcp(
117 listener["bind_addresses"],
118 listener["port"],
119 manhole(
120 username="matrix",
121 password="rabbithole",
122 globals={"hs": self},
146123 )
124 )
147125 else:
148126 logger.warn("Unrecognized listener type: %s", listener["type"])
149127
6161
6262
6363 class SynchrotronSlavedStore(
64 SlavedPushRuleStore,
65 SlavedEventStore,
6664 SlavedReceiptsStore,
6765 SlavedAccountDataStore,
6866 SlavedApplicationServiceStore,
7270 SlavedGroupServerStore,
7371 SlavedDeviceInboxStore,
7472 SlavedDeviceStore,
73 SlavedPushRuleStore,
74 SlavedEventStore,
7575 SlavedClientIpStore,
7676 RoomStore,
7777 BaseSlavedStore,
7878 ):
79 who_forgot_in_room = (
80 RoomMemberStore.__dict__["who_forgot_in_room"]
81 )
82
8379 did_forget = (
8480 RoomMemberStore.__dict__["did_forget"]
8581 )
245241
246242
247243 class SynchrotronServer(HomeServer):
248 def get_db_conn(self, run_new_connection=True):
249 # Any param beginning with cp_ is a parameter for adbapi, and should
250 # not be passed to the database engine.
251 db_params = {
252 k: v for k, v in self.db_config.get("args", {}).items()
253 if not k.startswith("cp_")
254 }
255 db_conn = self.database_engine.module.connect(**db_params)
256
257 if run_new_connection:
258 self.database_engine.on_new_connection(db_conn)
259 return db_conn
260
261244 def setup(self):
262245 logger.info("Setting up.")
263246 self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
287270
288271 root_resource = create_resource_tree(resources, Resource())
289272
290 for address in bind_addresses:
291 reactor.listenTCP(
292 port,
293 SynapseSite(
294 "synapse.access.http.%s" % (site_tag,),
295 site_tag,
296 listener_config,
297 root_resource,
298 ),
299 interface=address
300 )
273 _base.listen_tcp(
274 bind_addresses,
275 port,
276 SynapseSite(
277 "synapse.access.http.%s" % (site_tag,),
278 site_tag,
279 listener_config,
280 root_resource,
281 )
282 )
301283
302284 logger.info("Synapse synchrotron now listening on port %d", port)
303285
306288 if listener["type"] == "http":
307289 self._listen_http(listener)
308290 elif listener["type"] == "manhole":
309 bind_addresses = listener["bind_addresses"]
310
311 for address in bind_addresses:
312 reactor.listenTCP(
313 listener["port"],
314 manhole(
315 username="matrix",
316 password="rabbithole",
317 globals={"hs": self},
318 ),
319 interface=address
291 _base.listen_tcp(
292 listener["bind_addresses"],
293 listener["port"],
294 manhole(
295 username="matrix",
296 password="rabbithole",
297 globals={"hs": self},
320298 )
299 )
321300 else:
322301 logger.warn("Unrecognized listener type: %s", listener["type"])
323302
339318
340319 self.store = hs.get_datastore()
341320 self.typing_handler = hs.get_typing_handler()
321 # NB this is a SynchrotronPresence, not a normal PresenceHandler
342322 self.presence_handler = hs.get_presence_handler()
343323 self.notifier = hs.get_notifier()
344
345 self.presence_handler.sync_callback = self.send_user_sync
346324
347325 def on_rdata(self, stream_name, token, rows):
348326 super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
183183 worker_configfiles.append(worker_configfile)
184184
185185 if options.all_processes:
186 # To start the main synapse with -a you need to add a worker file
187 # with worker_app == "synapse.app.homeserver"
188 start_stop_synapse = False
186189 worker_configdir = options.all_processes
187190 if not os.path.isdir(worker_configdir):
188191 write(
199202 with open(worker_configfile) as stream:
200203 worker_config = yaml.load(stream)
201204 worker_app = worker_config["worker_app"]
202 worker_pidfile = worker_config["worker_pid_file"]
203 worker_daemonize = worker_config["worker_daemonize"]
204 assert worker_daemonize, "In config %r: expected '%s' to be True" % (
205 worker_configfile, "worker_daemonize")
206 worker_cache_factor = worker_config.get("synctl_cache_factor")
205 if worker_app == "synapse.app.homeserver":
206 # We need to special case all of this to pick up options that may
207 # be set in the main config file or in this worker config file.
208 worker_pidfile = (
209 worker_config.get("pid_file")
210 or pidfile
211 )
212 worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor
213 daemonize = worker_config.get("daemonize") or config.get("daemonize")
214 assert daemonize, "Main process must have daemonize set to true"
215
216 # The master process doesn't support using worker_* config.
217 for key in worker_config:
218 if key == "worker_app": # But we allow worker_app
219 continue
220 assert not key.startswith("worker_"), \
221 "Main process cannot use worker_* config"
222 else:
223 worker_pidfile = worker_config["worker_pid_file"]
224 worker_daemonize = worker_config["worker_daemonize"]
225 assert worker_daemonize, "In config %r: expected '%s' to be True" % (
226 worker_configfile, "worker_daemonize")
227 worker_cache_factor = worker_config.get("synctl_cache_factor")
207228 workers.append(Worker(
208229 worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
209230 ))
9191
9292
9393 class UserDirectoryServer(HomeServer):
94 def get_db_conn(self, run_new_connection=True):
95 # Any param beginning with cp_ is a parameter for adbapi, and should
96 # not be passed to the database engine.
97 db_params = {
98 k: v for k, v in self.db_config.get("args", {}).items()
99 if not k.startswith("cp_")
100 }
101 db_conn = self.database_engine.module.connect(**db_params)
102
103 if run_new_connection:
104 self.database_engine.on_new_connection(db_conn)
105 return db_conn
106
10794 def setup(self):
10895 logger.info("Setting up.")
10996 self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
130117
131118 root_resource = create_resource_tree(resources, Resource())
132119
133 for address in bind_addresses:
134 reactor.listenTCP(
135 port,
136 SynapseSite(
137 "synapse.access.http.%s" % (site_tag,),
138 site_tag,
139 listener_config,
140 root_resource,
141 ),
142 interface=address
120 _base.listen_tcp(
121 bind_addresses,
122 port,
123 SynapseSite(
124 "synapse.access.http.%s" % (site_tag,),
125 site_tag,
126 listener_config,
127 root_resource,
143128 )
129 )
144130
145131 logger.info("Synapse user_dir now listening on port %d", port)
146132
149135 if listener["type"] == "http":
150136 self._listen_http(listener)
151137 elif listener["type"] == "manhole":
152 bind_addresses = listener["bind_addresses"]
153
154 for address in bind_addresses:
155 reactor.listenTCP(
156 listener["port"],
157 manhole(
158 username="matrix",
159 password="rabbithole",
160 globals={"hs": self},
161 ),
162 interface=address
138 _base.listen_tcp(
139 listener["bind_addresses"],
140 listener["port"],
141 manhole(
142 username="matrix",
143 password="rabbithole",
144 globals={"hs": self},
163145 )
146 )
164147 else:
165148 logger.warn("Unrecognized listener type: %s", listener["type"])
166149
1313 # limitations under the License.
1414 from synapse.api.constants import EventTypes
1515 from synapse.util.caches.descriptors import cachedInlineCallbacks
16 from synapse.types import GroupID, get_domain_from_id
1617
1718 from twisted.internet import defer
1819
8081 # values.
8182 NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
8283
83 def __init__(self, token, url=None, namespaces=None, hs_token=None,
84 def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None,
8485 sender=None, id=None, protocols=None, rate_limited=True):
8586 self.token = token
8687 self.url = url
8788 self.hs_token = hs_token
8889 self.sender = sender
90 self.server_name = hostname
8991 self.namespaces = self._check_namespaces(namespaces)
9092 self.id = id
9193
124126 raise ValueError(
125127 "Expected bool for 'exclusive' in ns '%s'" % ns
126128 )
129 group_id = regex_obj.get("group_id")
130 if group_id:
131 if not isinstance(group_id, str):
132 raise ValueError(
133 "Expected string for 'group_id' in ns '%s'" % ns
134 )
135 try:
136 GroupID.from_string(group_id)
137 except Exception:
138 raise ValueError(
139 "Expected valid group ID for 'group_id' in ns '%s'" % ns
140 )
141
142 if get_domain_from_id(group_id) != self.server_name:
143 raise ValueError(
144 "Expected 'group_id' to be this host in ns '%s'" % ns
145 )
146
127147 regex = regex_obj.get("regex")
128148 if isinstance(regex, basestring):
129149 regex_obj["regex"] = re.compile(regex) # Pre-compile regex
250270 if regex_obj["exclusive"]
251271 ]
252272
273 def get_groups_for_user(self, user_id):
274 """Get the groups that this user is associated with by this AS
275
276 Args:
277 user_id (str): The ID of the user.
278
279 Returns:
280 iterable[str]: an iterable that yields group_id strings.
281 """
282 return (
283 regex_obj["group_id"]
284 for regex_obj in self.namespaces[ApplicationService.NS_USERS]
285 if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
286 )
287
253288 def is_rate_limited(self):
254289 return self.rate_limited
255290
1717 from synapse.api.errors import CodeMessageException
1818 from synapse.http.client import SimpleHttpClient
1919 from synapse.events.utils import serialize_event
20 from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
2021 from synapse.util.caches.response_cache import ResponseCache
2122 from synapse.types import ThirdPartyInstanceID
2223
191192 defer.returnValue(None)
192193
193194 key = (service.id, protocol)
194 return self.protocol_meta_cache.get(key) or (
195 self.protocol_meta_cache.set(key, _get())
196 )
195 result = self.protocol_meta_cache.get(key)
196 if not result:
197 result = self.protocol_meta_cache.set(
198 key, preserve_fn(_get)()
199 )
200 return make_deferred_yieldable(result)
197201
198202 @defer.inlineCallbacks
199203 def push_bulk(self, service, events, txn_id=None):
122122 with Measure(self.clock, "servicequeuer.send"):
123123 try:
124124 yield self.txn_ctrl.send(service, events)
125 except:
125 except Exception:
126126 logger.exception("AS request failed")
127127 finally:
128128 self.requests_in_flight.discard(service.id)
153153 )
154154 return ApplicationService(
155155 token=as_info["as_token"],
156 hostname=hostname,
156157 url=as_info["url"],
157158 namespaces=as_info["namespaces"],
158159 hs_token=as_info["hs_token"],
4040 #cas_config:
4141 # enabled: true
4242 # server_url: "https://cas-server.com"
43 # service_url: "https://homesever.domain.com:8448"
43 # service_url: "https://homeserver.domain.com:8448"
4444 # #required_attributes:
4545 # # name: value
4646 """
3535 from .push import PushConfig
3636 from .spam_checker import SpamCheckerConfig
3737 from .groups import GroupsConfig
38 from .user_directory import UserDirectoryConfig
3839
3940
4041 class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
4344 AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
4445 JWTConfig, PasswordConfig, EmailConfig,
4546 WorkerConfig, PasswordAuthProviderConfig, PushConfig,
46 SpamCheckerConfig, GroupsConfig,):
47 SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,):
4748 pass
4849
4950
2727 version: 1
2828
2929 formatters:
30 precise:
31 format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\
32 - %(message)s'
30 precise:
31 format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \
32 %(request)s - %(message)s'
3333
3434 filters:
35 context:
36 (): synapse.util.logcontext.LoggingContextFilter
37 request: ""
35 context:
36 (): synapse.util.logcontext.LoggingContextFilter
37 request: ""
3838
3939 handlers:
40 file:
41 class: logging.handlers.RotatingFileHandler
42 formatter: precise
43 filename: ${log_file}
44 maxBytes: 104857600
45 backupCount: 10
46 filters: [context]
47 console:
48 class: logging.StreamHandler
49 formatter: precise
50 filters: [context]
40 file:
41 class: logging.handlers.RotatingFileHandler
42 formatter: precise
43 filename: ${log_file}
44 maxBytes: 104857600
45 backupCount: 10
46 filters: [context]
47 console:
48 class: logging.StreamHandler
49 formatter: precise
50 filters: [context]
5151
5252 loggers:
5353 synapse:
7373 self.log_file = self.abspath(config.get("log_file"))
7474
7575 def default_config(self, config_dir_path, server_name, **kwargs):
76 log_file = self.abspath("homeserver.log")
7776 log_config = self.abspath(
7877 os.path.join(config_dir_path, server_name + ".log.config")
7978 )
8079 return """
81 # Logging verbosity level. Ignored if log_config is specified.
82 verbose: 0
83
84 # File to write logging to. Ignored if log_config is specified.
85 log_file: "%(log_file)s"
86
8780 # A yaml python logging config file
8881 log_config: "%(log_config)s"
8982 """ % locals()
122115 def generate_files(self, config):
123116 log_config = config.get("log_config")
124117 if log_config and not os.path.exists(log_config):
118 log_file = self.abspath("homeserver.log")
125119 with open(log_config, "wb") as log_config_file:
126120 log_config_file.write(
127 DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
121 DEFAULT_LOG_CONFIG.substitute(log_file=log_file)
128122 )
129123
130124
147141 "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
148142 " - %(message)s"
149143 )
144
150145 if log_config is None:
151
146 # We don't have a logfile, so fall back to the 'verbosity' param from
147 # the config or cmdline. (Note that we generate a log config for new
148 # installs, so this will be an unusual case)
152149 level = logging.INFO
153150 level_for_storage = logging.INFO
154151 if config.verbosity:
156153 if config.verbosity > 1:
157154 level_for_storage = logging.DEBUG
158155
159 # FIXME: we need a logging.WARN for a -q quiet option
160156 logger = logging.getLogger('')
161157 logger.setLevel(level)
162158
163 logging.getLogger('synapse.storage').setLevel(level_for_storage)
159 logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage)
164160
165161 formatter = logging.Formatter(log_format)
166162 if log_file:
175171 logger.info("Opened new log file due to SIGHUP")
176172 else:
177173 handler = logging.StreamHandler()
174
175 def sighup(signum, stack):
176 pass
177
178178 handler.setFormatter(formatter)
179179
180180 handler.addFilter(LoggingContextFilter(request=""))
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 from ._base import Config, ConfigError
15 from ._base import Config
1616
1717 from synapse.util.module_loader import load_module
18
19 LDAP_PROVIDER = 'ldap_auth_provider.LdapAuthProvider'
1820
1921
2022 class PasswordAuthProviderConfig(Config):
2123 def read_config(self, config):
2224 self.password_providers = []
23
24 provider_config = None
25 providers = []
2526
2627 # We want to be backwards compatible with the old `ldap_config`
2728 # param.
2829 ldap_config = config.get("ldap_config", {})
29 self.ldap_enabled = ldap_config.get("enabled", False)
30 if self.ldap_enabled:
31 from ldap_auth_provider import LdapAuthProvider
32 parsed_config = LdapAuthProvider.parse_config(ldap_config)
33 self.password_providers.append((LdapAuthProvider, parsed_config))
30 if ldap_config.get("enabled", False):
31 providers.append({
32 'module': LDAP_PROVIDER,
33 'config': ldap_config,
34 })
3435
35 providers = config.get("password_providers", [])
36 providers.extend(config.get("password_providers", []))
3637 for provider in providers:
38 mod_name = provider['module']
39
3740 # This is for backwards compat when the ldap auth provider resided
3841 # in this package.
39 if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider":
40 from ldap_auth_provider import LdapAuthProvider
41 provider_class = LdapAuthProvider
42 try:
43 provider_config = provider_class.parse_config(provider["config"])
44 except Exception as e:
45 raise ConfigError(
46 "Failed to parse config for %r: %r" % (provider['module'], e)
47 )
48 else:
49 (provider_class, provider_config) = load_module(provider)
42 if mod_name == "synapse.util.ldap_auth_provider.LdapAuthProvider":
43 mod_name = LDAP_PROVIDER
44
45 (provider_class, provider_config) = load_module({
46 "module": mod_name,
47 "config": provider['config'],
48 })
5049
5150 self.password_providers.append((provider_class, provider_config))
5251
00 # -*- coding: utf-8 -*-
11 # Copyright 2015, 2016 OpenMarket Ltd
2 # Copyright 2017 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1718
1819 class PushConfig(Config):
1920 def read_config(self, config):
20 self.push_redact_content = False
21 push_config = config.get("push", {})
22 self.push_include_content = push_config.get("include_content", True)
2123
24 # There was a 'redact_content' setting, but it was mistakenly read from the
25 # 'email' section. Check for the flag in the 'push' section, and log,
26 # but do not honour it to avoid nasty surprises when people upgrade.
27 if push_config.get("redact_content") is not None:
28 print(
29 "The push.redact_content content option has never worked. "
30 "Please set push.include_content if you want this behaviour"
31 )
32
33 # Now check for the one in the 'email' section and honour it,
34 # with a warning.
2235 push_config = config.get("email", {})
23 self.push_redact_content = push_config.get("redact_content", False)
36 redact_content = push_config.get("redact_content")
37 if redact_content is not None:
38 print(
39 "The 'email.redact_content' option is deprecated: "
40 "please set push.include_content instead"
41 )
42 self.push_include_content = not redact_content
2443
2544 def default_config(self, config_dir_path, server_name, **kwargs):
2645 return """
27 # Control how push messages are sent to google/apple to notifications.
28 # Normally every message said in a room with one or more people using
29 # mobile devices will be posted to a push server hosted by matrix.org
30 # which is registered with google and apple in order to allow push
31 # notifications to be sent to these mobile devices.
32 #
33 # Setting redact_content to true will make the push messages contain no
34 # message content which will provide increased privacy. This is a
35 # temporary solution pending improvements to Android and iPhone apps
36 # to get content from the app rather than the notification.
37 #
46 # Clients requesting push notifications can either have the body of
47 # the message sent in the notification poke along with other details
48 # like the sender, or just the event ID and room ID (`event_id_only`).
49 # If clients choose the former, this option controls whether the
50 # notification request includes the content of the event (other details
51 # like the sender are still included). For `event_id_only` push, it
52 # has no effect.
53
3854 # For modern android devices the notification content will still appear
3955 # because it is loaded by the app. iPhone, however will send a
4056 # notification saying only that a message arrived and who it came from.
4157 #
4258 #push:
43 # redact_content: false
59 # include_content: true
4460 """
3030 strtobool(str(config["disable_registration"]))
3131 )
3232
33 self.registrations_require_3pid = config.get("registrations_require_3pid", [])
34 self.allowed_local_3pids = config.get("allowed_local_3pids", [])
3335 self.registration_shared_secret = config.get("registration_shared_secret")
3436
3537 self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
5153 # Enable registration for new users.
5254 enable_registration: False
5355
56 # The user must provide all of the below types of 3PID when registering.
57 #
58 # registrations_require_3pid:
59 # - email
60 # - msisdn
61
62 # Mandate that users are only allowed to associate certain formats of
63 # 3PIDs with accounts on this server.
64 #
65 # allowed_local_3pids:
66 # - medium: email
67 # pattern: ".*@matrix\\.org"
68 # - medium: email
69 # pattern: ".*@vector\\.im"
70 # - medium: msisdn
71 # pattern: "\\+44"
72
5473 # If set, allows registration by anyone who also has the shared
5574 # secret, even if registration is otherwise disabled.
5675 registration_shared_secret: "%(registration_shared_secret)s"
5776
5877 # Set the number of bcrypt rounds used to generate password hash.
5978 # Larger numbers increase the work factor needed to generate the hash.
60 # The default number of rounds is 12.
79 # The default number is 12 (which equates to 2^12 rounds).
80 # N.B. that increasing this will exponentially increase the time required
81 # to register or login - e.g. 24 => 2^24 rounds which will take >20 mins.
6182 bcrypt_rounds: 12
6283
6384 # Allows users to register as guests without a password/email/etc, and
1515 from ._base import Config, ConfigError
1616 from collections import namedtuple
1717
18 from synapse.util.module_loader import load_module
19
1820
1921 MISSING_NETADDR = (
2022 "Missing netaddr library. This is required for URL preview API."
3335
3436 ThumbnailRequirement = namedtuple(
3537 "ThumbnailRequirement", ["width", "height", "method", "media_type"]
38 )
39
40 MediaStorageProviderConfig = namedtuple(
41 "MediaStorageProviderConfig", (
42 "store_local", # Whether to store newly uploaded local files
43 "store_remote", # Whether to store newly downloaded remote files
44 "store_synchronous", # Whether to wait for successful storage for local uploads
45 ),
3646 )
3747
3848
7282
7383 self.media_store_path = self.ensure_directory(config["media_store_path"])
7484
75 self.backup_media_store_path = config.get("backup_media_store_path")
76 if self.backup_media_store_path:
77 self.backup_media_store_path = self.ensure_directory(
78 self.backup_media_store_path
79 )
80
81 self.synchronous_backup_media_store = config.get(
85 backup_media_store_path = config.get("backup_media_store_path")
86
87 synchronous_backup_media_store = config.get(
8288 "synchronous_backup_media_store", False
8389 )
90
91 storage_providers = config.get("media_storage_providers", [])
92
93 if backup_media_store_path:
94 if storage_providers:
95 raise ConfigError(
96 "Cannot use both 'backup_media_store_path' and 'storage_providers'"
97 )
98
99 storage_providers = [{
100 "module": "file_system",
101 "store_local": True,
102 "store_synchronous": synchronous_backup_media_store,
103 "store_remote": True,
104 "config": {
105 "directory": backup_media_store_path,
106 }
107 }]
108
109 # This is a list of config that can be used to create the storage
110 # providers. The entries are tuples of (Class, class_config,
111 # MediaStorageProviderConfig), where Class is the class of the provider,
112 # the class_config the config to pass to it, and
113 # MediaStorageProviderConfig are options for StorageProviderWrapper.
114 #
115 # We don't create the storage providers here as not all workers need
116 # them to be started.
117 self.media_storage_providers = []
118
119 for provider_config in storage_providers:
120 # We special case the module "file_system" so as not to need to
121 # expose FileStorageProviderBackend
122 if provider_config["module"] == "file_system":
123 provider_config["module"] = (
124 "synapse.rest.media.v1.storage_provider"
125 ".FileStorageProviderBackend"
126 )
127
128 provider_class, parsed_config = load_module(provider_config)
129
130 wrapper_config = MediaStorageProviderConfig(
131 provider_config.get("store_local", False),
132 provider_config.get("store_remote", False),
133 provider_config.get("store_synchronous", False),
134 )
135
136 self.media_storage_providers.append(
137 (provider_class, parsed_config, wrapper_config,)
138 )
84139
85140 self.uploads_path = self.ensure_directory(config["uploads_path"])
86141 self.dynamic_thumbnails = config["dynamic_thumbnails"]
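
For the ``file_system`` provider shown in the sample config further down, the resulting entry in ``self.media_storage_providers`` would look roughly as follows (a sketch; the middle element is whatever the provider's config parsing returns for the configured directory):

    # Rough shape of one parsed entry (illustrative, not actual output):
    #   (FileStorageProviderBackend,
    #    <parsed config for directory "/mnt/some/other/directory">,
    #    MediaStorageProviderConfig(store_local=False,
    #                               store_remote=False,
    #                               store_synchronous=False))
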
126181 # Directory where uploaded images and attachments are stored.
127182 media_store_path: "%(media_store)s"
128183
129 # A secondary directory where uploaded images and attachments are
130 # stored as a backup.
131 # backup_media_store_path: "%(media_store)s"
132
133 # Whether to wait for successful write to backup media store before
134 # returning successfully.
135 # synchronous_backup_media_store: false
184 # Media storage providers allow media to be stored in different
185 # locations.
186 # media_storage_providers:
187 # - module: file_system
188 # # Whether to write new local files.
189 # store_local: false
190 # # Whether to write new remote media
191 # store_remote: false
192 # # Whether to block upload requests waiting for write to this
193 # # provider to complete
194 # store_synchronous: false
195 # config:
196 # directory: /mnt/some/other/directory
136197
137198 # Directory where in-progress uploads are stored.
138199 uploads_path: "%(uploads_path)s"
4040 # false only if we are updating the user directory in a worker
4141 self.update_user_directory = config.get("update_user_directory", True)
4242
43 # whether to enable the media repository endpoints. This should be set
44 # to false if the media repository is running as a separate worker;
45 # doing so ensures that we will not run cache cleanup jobs on the
46 # master, potentially causing inconsistency.
47 self.enable_media_repo = config.get("enable_media_repo", True)
48
4349 self.filter_timeline_limit = config.get("filter_timeline_limit", -1)
4450
4551 # Whether we should block invites sent to users on this server
4753 self.block_non_admin_invites = config.get(
4854 "block_non_admin_invites", False,
4955 )
56
57 # FIXME: federation_domain_whitelist needs sytests
58 self.federation_domain_whitelist = None
59 federation_domain_whitelist = config.get(
60 "federation_domain_whitelist", None
61 )
62 # turn the whitelist into a hash for speed of lookup
63 if federation_domain_whitelist is not None:
64 self.federation_domain_whitelist = {}
65 for domain in federation_domain_whitelist:
66 self.federation_domain_whitelist[domain] = True
5067
5168 if self.public_baseurl is not None:
5269 if self.public_baseurl[-1] != '/':
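
``federation_domain_whitelist`` is stored as a dict (used as a set) so lookups are cheap, with ``None`` meaning no restriction. A hypothetical helper showing how such a whitelist would be consulted (not the actual Synapse call site):

    # Hypothetical lookup (illustrative): None allows every domain,
    # otherwise only whitelisted domains may federate with us.
    def is_domain_allowed(config, domain):
        if config.federation_domain_whitelist is None:
            return True
        return domain in config.federation_domain_whitelist
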
203220 # (except those sent by local server admins). The default is False.
204221 # block_non_admin_invites: True
205222
223 # Restrict federation to the following whitelist of domains.
224 # N.B. we recommend also firewalling your federation listener to limit
225 # inbound federation traffic as early as possible, rather than relying
226 # purely on this application-layer restriction. If not specified, the
227 # default is to whitelist everything.
228 #
229 # federation_domain_whitelist:
230 # - lon.example.com
231 # - nyc.example.com
232 # - syd.example.com
233
206234 # List of ports that Synapse should listen on, their purpose and their
207235 # configuration.
208236 listeners:
213241 port: %(bind_port)s
214242
215243 # Local addresses to listen on.
216 # This will listen on all IPv4 addresses by default.
244 # On Linux and Mac OS, `::` will listen on all IPv4 and IPv6
245 # addresses by default. For most other OSes, this will only listen
246 # on IPv6.
217247 bind_addresses:
248 - '::'
218249 - '0.0.0.0'
219 # Uncomment to listen on all IPv6 interfaces
220 # N.B: On at least Linux this will also listen on all IPv4
221 # addresses, so you will need to comment out the line above.
222 # - '::'
223250
224251 # This is a 'http' listener, allows us to specify 'resources'.
225252 type: http
246273 - names: [federation] # Federation APIs
247274 compress: false
248275
276 # optional list of additional endpoints which can be loaded via
277 # dynamic modules
278 # additional_resources:
279 # "/_matrix/my/custom/endpoint":
280 # module: my_module.CustomRequestHandler
281 # config: {}
282
249283 # Unsecure HTTP listener,
250284 # For when matrix traffic passes through loadbalancer that unwraps TLS.
251285 - port: %(unsecure_port)s
252286 tls: false
253 bind_addresses: ['0.0.0.0']
287 bind_addresses: ['::', '0.0.0.0']
254288 type: http
255289
256290 x_forwarded: false
264298 # Turn on the twisted ssh manhole service on localhost on the given
265299 # port.
266300 # - port: 9000
267 # bind_address: 127.0.0.1
301 # bind_addresses: ['::1', '127.0.0.1']
268302 # type: manhole
269303 """ % locals()
270304
302336 return (
303337 int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
304338 )
305 except:
339 except Exception:
306340 raise ConfigError(
307341 "Value of `gc_threshold` must be a list of three integers if set"
308342 )
9595 # certificates returned by this server match one of the fingerprints.
9696 #
9797 # Synapse automatically adds the fingerprint of its own certificate
98 # to the list. So if federation traffic is handle directly by synapse
98 # to the list. So if federation traffic is handled directly by synapse
9999 # then no modification to the list is required.
100100 #
101101 # If synapse is run behind a load balancer that handles the TLS then it
108108 # key. It may be necessary to publish the fingerprints of a new
109109 # certificate and wait until the "valid_until_ts" of the previous key
110110 # responses have passed before deploying it.
111 #
112 # You can calculate a fingerprint from a given TLS listener via:
113 # openssl s_client -connect $host:$port < /dev/null 2> /dev/null |
114 # openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '='
115 # or by checking matrix.org/federationtester/api/report?server_name=$host
116 #
111117 tls_fingerprints: []
112118 # tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
113119 """ % locals()
0 # -*- coding: utf-8 -*-
1 # Copyright 2017 New Vector Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from ._base import Config
16
17
18 class UserDirectoryConfig(Config):
19 """User Directory Configuration
20 Configuration for the behaviour of the /user_directory API
21 """
22
23 def read_config(self, config):
24 self.user_directory_search_all_users = False
25 user_directory_config = config.get("user_directory", None)
26 if user_directory_config:
27 self.user_directory_search_all_users = (
28 user_directory_config.get("search_all_users", False)
29 )
30
31 def default_config(self, config_dir_path, server_name, **kwargs):
32 return """
33 # User Directory configuration
34 #
35 # 'search_all_users' defines whether to search all users visible to your HS
36 # when searching the user directory, rather than limiting to users visible
37 # in public rooms. Defaults to false. If you set it to true, you'll have to run
38 # UPDATE user_directory_stream_pos SET stream_id = NULL;
39 # on your database to tell it to rebuild the user_directory search indexes.
40 #
41 #user_directory:
42 # search_all_users: false
43 """
2222
2323 def read_config(self, config):
2424 self.worker_app = config.get("worker_app")
25
26 # Canonicalise worker_app so that master always has None
27 if self.worker_app == "synapse.app.homeserver":
28 self.worker_app = None
29
2530 self.worker_listeners = config.get("worker_listeners")
2631 self.worker_daemonize = config.get("worker_daemonize")
2732 self.worker_pid_file = config.get("worker_pid_file")
2833 self.worker_log_file = config.get("worker_log_file")
2934 self.worker_log_config = config.get("worker_log_config")
35
36 # The host used to connect to the main synapse
3037 self.worker_replication_host = config.get("worker_replication_host", None)
38
39 # The port on the main synapse for TCP replication
3140 self.worker_replication_port = config.get("worker_replication_port", None)
41
42 # The port on the main synapse for HTTP replication endpoint
43 self.worker_replication_http_port = config.get("worker_replication_http_port")
44
3245 self.worker_name = config.get("worker_name", self.worker_app)
3346
3447 self.worker_main_http_uri = config.get("worker_main_http_uri", None)
3333 try:
3434 _ecCurve = _OpenSSLECCurve(_defaultCurveName)
3535 _ecCurve.addECKeyToContext(context)
36 except:
36 except Exception:
3737 logger.exception("Failed to enable elliptic curve for TLS")
3838 context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
3939 context.use_certificate_chain_file(config.tls_certificate_file)
3131 """Check whether the hash for this PDU matches the contents"""
3232 name, expected_hash = compute_content_hash(event, hash_algorithm)
3333 logger.debug("Expecting hash: %s", encode_base64(expected_hash))
34 if name not in event.hashes:
34
35 # some malformed events lack a 'hashes'. Protect against it being missing
36 # or a weird type by basically treating it the same as an unhashed event.
37 hashes = event.get("hashes")
38 if not isinstance(hashes, dict):
39 raise SynapseError(400, "Malformed 'hashes'", Codes.UNAUTHORIZED)
40
41 if name not in hashes:
3542 raise SynapseError(
3643 400,
3744 "Algorithm %s not in hashes %s" % (
38 name, list(event.hashes),
45 name, list(hashes),
3946 ),
4047 Codes.UNAUTHORIZED,
4148 )
42 message_hash_base64 = event.hashes[name]
49 message_hash_base64 = hashes[name]
4350 try:
4451 message_hash_bytes = decode_base64(message_hash_base64)
45 except:
52 except Exception:
4653 raise SynapseError(
4754 400,
4855 "Invalid base64: %s" % (message_hash_base64,),
758758 ))
759759 try:
760760 verify_signed_json(json_object, server_name, verify_key)
761 except:
761 except Exception:
762762 raise SynapseError(
763763 401,
764764 "Invalid signature for server %s with key %s:%s" % (
318318 # TODO (erikj): Implement kicks.
319319 if target_banned and user_level < ban_level:
320320 raise AuthError(
321 403, "You cannot unban user &s." % (target_user_id,)
321 403, "You cannot unban user %s." % (target_user_id,)
322322 )
323323 elif target_user_id != event.user_id:
324324 kick_level = _get_named_level(auth_events, "kick", 50)
442442 for k, v in user_list.items():
443443 try:
444444 UserID.from_string(k)
445 except:
445 except Exception:
446446 raise SynapseError(400, "Not a valid user_id: %s" % (k,))
447447
448448 try:
449449 int(v)
450 except:
450 except Exception:
451451 raise SynapseError(400, "Not a valid power level: %s" % (v,))
452452
453453 key = (event.type, event.state_key, )
5454
5555 local_part = str(int(self.clock.time())) + i + random_string(5)
5656
57 e_id = EventID.create(local_part, self.hostname)
57 e_id = EventID(local_part, self.hostname)
5858
5959 return e_id.to_string()
6060
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 from twisted.internet import defer
16
17 from frozendict import frozendict
18
1519
1620 class EventContext(object):
1721 """
2428 The current state map excluding the current event.
2529 (type, state_key) -> event_id
2630
27 state_group (int): state group id
31 state_group (int|None): state group id, if the state has been stored
32 as a state group. This is usually only None if e.g. the event is
33 an outlier.
2834 rejected (bool|str): A rejection reason if the event was rejected, else
2935 False
3036
4551 "prev_state_ids",
4652 "state_group",
4753 "rejected",
48 "push_actions",
4954 "prev_group",
5055 "delta_ids",
5156 "prev_state_events",
6065 self.state_group = None
6166
6267 self.rejected = False
63 self.push_actions = []
6468
6569 # A previously persisted state group and a delta between that
6670 # and this state.
7074 self.prev_state_events = None
7175
7276 self.app_service = None
77
78 def serialize(self, event):
79 """Converts self to a type that can be serialized as JSON, and then
80 deserialized by `deserialize`
81
82 Args:
83 event (FrozenEvent): The event that this context relates to
84
85 Returns:
86 dict
87 """
88
89 # We don't serialize the full state dicts, instead they get pulled out
90 # of the DB on the other side. However, the other side can't figure out
91 # the prev_state_ids, so if we're a state event we include the event
92 # id that we replaced in the state.
93 if event.is_state():
94 prev_state_id = self.prev_state_ids.get((event.type, event.state_key))
95 else:
96 prev_state_id = None
97
98 return {
99 "prev_state_id": prev_state_id,
100 "event_type": event.type,
101 "event_state_key": event.state_key if event.is_state() else None,
102 "state_group": self.state_group,
103 "rejected": self.rejected,
104 "prev_group": self.prev_group,
105 "delta_ids": _encode_state_dict(self.delta_ids),
106 "prev_state_events": self.prev_state_events,
107 "app_service_id": self.app_service.id if self.app_service else None
108 }
109
110 @staticmethod
111 @defer.inlineCallbacks
112 def deserialize(store, input):
113 """Converts a dict that was produced by `serialize` back into a
114 EventContext.
115
116 Args:
117 store (DataStore): Used to convert AS ID to AS object
118 input (dict): A dict produced by `serialize`
119
120 Returns:
121 EventContext
122 """
123 context = EventContext()
124 context.state_group = input["state_group"]
125 context.rejected = input["rejected"]
126 context.prev_group = input["prev_group"]
127 context.delta_ids = _decode_state_dict(input["delta_ids"])
128 context.prev_state_events = input["prev_state_events"]
129
130 # We use the state_group and prev_state_id stuff to pull the
131 # current_state_ids out of the DB and construct prev_state_ids.
132 prev_state_id = input["prev_state_id"]
133 event_type = input["event_type"]
134 event_state_key = input["event_state_key"]
135
136 context.current_state_ids = yield store.get_state_ids_for_group(
137 context.state_group,
138 )
139 if prev_state_id and event_state_key:
140 context.prev_state_ids = dict(context.current_state_ids)
141 context.prev_state_ids[(event_type, event_state_key)] = prev_state_id
142 else:
143 context.prev_state_ids = context.current_state_ids
144
145 app_service_id = input["app_service_id"]
146 if app_service_id:
147 context.app_service = store.get_app_service_by_id(app_service_id)
148
149 defer.returnValue(context)
150
151
152 def _encode_state_dict(state_dict):
153 """Since dicts of (type, state_key) -> event_id cannot be serialized in
154 JSON we need to convert them to a form that can.
155 """
156 if state_dict is None:
157 return None
158
159 return [
160 (etype, state_key, v)
161 for (etype, state_key), v in state_dict.iteritems()
162 ]
163
164
165 def _decode_state_dict(input):
166 """Decodes a state dict encoded using `_encode_state_dict` above
167 """
168 if input is None:
169 return None
170
171 return frozendict({(etype, state_key,): v for etype, state_key, v in input})
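
A small illustration of the round-trip these two helpers perform (values are made up):

    # Round-trip illustration (hypothetical values):
    #   {("m.room.member", "@alice:example.com"): "$abc123"}
    #     --_encode_state_dict-->
    #   [("m.room.member", "@alice:example.com", "$abc123")]
    #     --_decode_state_dict-->
    #   frozendict({("m.room.member", "@alice:example.com"): "$abc123"})
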
2121 config = None
2222 try:
2323 module, config = hs.config.spam_checker
24 except:
24 except Exception:
2525 pass
2626
2727 if module is not None:
1414
1515 """ This package includes all the federation specific logic.
1616 """
17
18 from .replication import ReplicationLayer
19
20
21 def initialize_http_replication(hs):
22 transport = hs.get_federation_transport_client()
23
24 return ReplicationLayer(hs, transport)
1515
1616 from synapse.api.errors import SynapseError
1717 from synapse.crypto.event_signing import check_event_content_hash
18 from synapse.events import FrozenEvent
1819 from synapse.events.utils import prune_event
20 from synapse.http.servlet import assert_params_in_request
1921 from synapse.util import unwrapFirstError, logcontext
2022 from twisted.internet import defer
2123
2426
2527 class FederationBase(object):
2628 def __init__(self, hs):
29 self.hs = hs
30
31 self.server_name = hs.hostname
32 self.keyring = hs.get_keyring()
2733 self.spam_checker = hs.get_spam_checker()
34 self.store = hs.get_datastore()
35 self._clock = hs.get_clock()
2836
2937 @defer.inlineCallbacks
3038 def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
168176 )
169177
170178 return deferreds
179
180
181 def event_from_pdu_json(pdu_json, outlier=False):
182 """Construct a FrozenEvent from an event json received over federation
183
184 Args:
185 pdu_json (object): pdu as received over federation
186 outlier (bool): True to mark this event as an outlier
187
188 Returns:
189 FrozenEvent
190
191 Raises:
192 SynapseError: if the pdu is missing required fields
193 """
194 # we could probably enforce a bunch of other fields here (room_id, sender,
195 # origin, etc etc)
196 assert_params_in_request(pdu_json, ('event_id', 'type'))
197 event = FrozenEvent(
198 pdu_json
199 )
200
201 event.internal_metadata.outlier = outlier
202
203 return event
1313 # limitations under the License.
1414
1515
16 from twisted.internet import defer
17
18 from .federation_base import FederationBase
19 from synapse.api.constants import Membership
20
21 from synapse.api.errors import (
22 CodeMessageException, HttpResponseException, SynapseError,
23 )
24 from synapse.util import unwrapFirstError, logcontext
25 from synapse.util.caches.expiringcache import ExpiringCache
26 from synapse.util.logutils import log_function
27 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
28 from synapse.events import FrozenEvent, builder
29 import synapse.metrics
30
31 from synapse.util.retryutils import NotRetryingDestination
32
3316 import copy
3417 import itertools
3518 import logging
3619 import random
3720
21 from twisted.internet import defer
22
23 from synapse.api.constants import Membership
24 from synapse.api.errors import (
25 CodeMessageException, HttpResponseException, SynapseError, FederationDeniedError
26 )
27 from synapse.events import builder
28 from synapse.federation.federation_base import (
29 FederationBase,
30 event_from_pdu_json,
31 )
32 import synapse.metrics
33 from synapse.util import logcontext, unwrapFirstError
34 from synapse.util.caches.expiringcache import ExpiringCache
35 from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
36 from synapse.util.logutils import log_function
37 from synapse.util.retryutils import NotRetryingDestination
3838
3939 logger = logging.getLogger(__name__)
4040
5757 self._clear_tried_cache, 60 * 1000,
5858 )
5959 self.state = hs.get_state_handler()
60 self.transport_layer = hs.get_federation_transport_client()
6061
6162 def _clear_tried_cache(self):
6263 """Clear pdu_destination_tried cache"""
183184 logger.debug("backfill transaction_data=%s", repr(transaction_data))
184185
185186 pdus = [
186 self.event_from_pdu_json(p, outlier=False)
187 event_from_pdu_json(p, outlier=False)
187188 for p in transaction_data["pdus"]
188189 ]
189190
243244 logger.debug("transaction_data %r", transaction_data)
244245
245246 pdu_list = [
246 self.event_from_pdu_json(p, outlier=outlier)
247 event_from_pdu_json(p, outlier=outlier)
247248 for p in transaction_data["pdus"]
248249 ]
249250
263264 event_id, destination, e,
264265 )
265266 except NotRetryingDestination as e:
267 logger.info(e.message)
268 continue
269 except FederationDeniedError as e:
266270 logger.info(e.message)
267271 continue
268272 except Exception as e:
335339 )
336340
337341 pdus = [
338 self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
342 event_from_pdu_json(p, outlier=True) for p in result["pdus"]
339343 ]
340344
341345 auth_chain = [
342 self.event_from_pdu_json(p, outlier=True)
346 event_from_pdu_json(p, outlier=True)
343347 for p in result.get("auth_chain", [])
344348 ]
345349
419423 for e_id in batch
420424 ]
421425
422 res = yield preserve_context_over_deferred(
426 res = yield make_deferred_yieldable(
423427 defer.DeferredList(deferreds, consumeErrors=True)
424428 )
425429 for success, result in res:
440444 )
441445
442446 auth_chain = [
443 self.event_from_pdu_json(p, outlier=True)
447 event_from_pdu_json(p, outlier=True)
444448 for p in res["auth_chain"]
445449 ]
446450
569573 logger.debug("Got content: %s", content)
570574
571575 state = [
572 self.event_from_pdu_json(p, outlier=True)
576 event_from_pdu_json(p, outlier=True)
573577 for p in content.get("state", [])
574578 ]
575579
576580 auth_chain = [
577 self.event_from_pdu_json(p, outlier=True)
581 event_from_pdu_json(p, outlier=True)
578582 for p in content.get("auth_chain", [])
579583 ]
580584
649653
650654 logger.debug("Got response to send_invite: %s", pdu_dict)
651655
652 pdu = self.event_from_pdu_json(pdu_dict)
656 pdu = event_from_pdu_json(pdu_dict)
653657
654658 # Check signatures are correct.
655659 pdu = yield self._check_sigs_and_hash(pdu)
739743 )
740744
741745 auth_chain = [
742 self.event_from_pdu_json(e)
746 event_from_pdu_json(e)
743747 for e in content["auth_chain"]
744748 ]
745749
787791 )
788792
789793 events = [
790 self.event_from_pdu_json(e)
794 event_from_pdu_json(e)
791795 for e in content.get("events", [])
792796 ]
793797
803807 signed_events = []
804808
805809 defer.returnValue(signed_events)
806
807 def event_from_pdu_json(self, pdu_json, outlier=False):
808 event = FrozenEvent(
809 pdu_json
810 )
811
812 event.internal_metadata.outlier = outlier
813
814 return event
815810
816811 @defer.inlineCallbacks
817812 def forward_third_party_invite(self, destinations, room_id, event_dict):
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 import logging
15
16 import simplejson as json
1417 from twisted.internet import defer
1518
16 from .federation_base import FederationBase
17 from .units import Transaction, Edu
18
19 from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError
20 from synapse.crypto.event_signing import compute_event_signature
21 from synapse.federation.federation_base import (
22 FederationBase,
23 event_from_pdu_json,
24 )
25
26 from synapse.federation.persistence import TransactionActions
27 from synapse.federation.units import Edu, Transaction
28 import synapse.metrics
29 from synapse.types import get_domain_from_id
1930 from synapse.util import async
31 from synapse.util.caches.response_cache import ResponseCache
32 from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
2033 from synapse.util.logutils import log_function
21 from synapse.util.caches.response_cache import ResponseCache
22 from synapse.events import FrozenEvent
23 from synapse.types import get_domain_from_id
24 import synapse.metrics
25
26 from synapse.api.errors import AuthError, FederationError, SynapseError
27
28 from synapse.crypto.event_signing import compute_event_signature
29
30 import simplejson as json
31 import logging
3234
3335 # when processing incoming transactions, we try to handle multiple rooms in
3436 # parallel, up to this limit.
5153 super(FederationServer, self).__init__(hs)
5254
5355 self.auth = hs.get_auth()
56 self.handler = hs.get_handlers().federation_handler
5457
5558 self._server_linearizer = async.Linearizer("fed_server")
5659 self._transaction_linearizer = async.Linearizer("fed_txn_handler")
60
61 self.transaction_actions = TransactionActions(self.store)
62
63 self.registry = hs.get_federation_registry()
5764
5865 # We cache responses to state queries, as they take a while and often
5966 # come in waves.
6067 self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
61
62 def set_handler(self, handler):
63 """Sets the handler that the replication layer will use to communicate
64 receipt of new PDUs from other home servers. The required methods are
65 documented on :py:class:`.ReplicationHandler`.
66 """
67 self.handler = handler
68
69 def register_edu_handler(self, edu_type, handler):
70 if edu_type in self.edu_handlers:
71 raise KeyError("Already have an EDU handler for %s" % (edu_type,))
72
73 self.edu_handlers[edu_type] = handler
74
75 def register_query_handler(self, query_type, handler):
76 """Sets the handler callable that will be used to handle an incoming
77 federation Query of the given type.
78
79 Args:
80 query_type (str): Category name of the query, which should match
81 the string used by make_query.
82 handler (callable): Invoked to handle incoming queries of this type
83
84 handler is invoked as:
85 result = handler(args)
86
87 where 'args' is a dict mapping strings to strings of the query
88 arguments. It should return a Deferred that will eventually yield an
89 object to encode as JSON.
90 """
91 if query_type in self.query_handlers:
92 raise KeyError(
93 "Already have a Query handler for %s" % (query_type,)
94 )
95
96 self.query_handlers[query_type] = handler
9768
9869 @defer.inlineCallbacks
9970 @log_function
170141 p["age_ts"] = request_time - int(p["age"])
171142 del p["age"]
172143
173 event = self.event_from_pdu_json(p)
144 event = event_from_pdu_json(p)
174145 room_id = event.room_id
175146 pdus_by_room.setdefault(room_id, []).append(event)
176147
228199 @defer.inlineCallbacks
229200 def received_edu(self, origin, edu_type, content):
230201 received_edus_counter.inc()
231
232 if edu_type in self.edu_handlers:
233 try:
234 yield self.edu_handlers[edu_type](origin, content)
235 except SynapseError as e:
236 logger.info("Failed to handle edu %r: %r", edu_type, e)
237 except Exception as e:
238 logger.exception("Failed to handle edu %r", edu_type)
239 else:
240 logger.warn("Received EDU of type %s with no handler", edu_type)
202 yield self.registry.on_edu(edu_type, origin, content)
241203
242204 @defer.inlineCallbacks
243205 @log_function
252214 result = self._state_resp_cache.get((room_id, event_id))
253215 if not result:
254216 with (yield self._server_linearizer.queue((origin, room_id))):
255 resp = yield self._state_resp_cache.set(
217 d = self._state_resp_cache.set(
256218 (room_id, event_id),
257 self._on_context_state_request_compute(room_id, event_id)
219 preserve_fn(self._on_context_state_request_compute)(room_id, event_id)
258220 )
221 resp = yield make_deferred_yieldable(d)
259222 else:
260 resp = yield result
223 resp = yield make_deferred_yieldable(result)
261224
262225 defer.returnValue((200, resp))
263226
326289 @defer.inlineCallbacks
327290 def on_query_request(self, query_type, args):
328291 received_queries_counter.inc(query_type)
329
330 if query_type in self.query_handlers:
331 response = yield self.query_handlers[query_type](args)
332 defer.returnValue((200, response))
333 else:
334 defer.returnValue(
335 (404, "No handler for Query type '%s'" % (query_type,))
336 )
292 resp = yield self.registry.on_query(query_type, args)
293 defer.returnValue((200, resp))
337294
338295 @defer.inlineCallbacks
339296 def on_make_join_request(self, room_id, user_id):
343300
344301 @defer.inlineCallbacks
345302 def on_invite_request(self, origin, content):
346 pdu = self.event_from_pdu_json(content)
303 pdu = event_from_pdu_json(content)
347304 ret_pdu = yield self.handler.on_invite_request(origin, pdu)
348305 time_now = self._clock.time_msec()
349306 defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
351308 @defer.inlineCallbacks
352309 def on_send_join_request(self, origin, content):
353310 logger.debug("on_send_join_request: content: %s", content)
354 pdu = self.event_from_pdu_json(content)
311 pdu = event_from_pdu_json(content)
355312 logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
356313 res_pdus = yield self.handler.on_send_join_request(origin, pdu)
357314 time_now = self._clock.time_msec()
371328 @defer.inlineCallbacks
372329 def on_send_leave_request(self, origin, content):
373330 logger.debug("on_send_leave_request: content: %s", content)
374 pdu = self.event_from_pdu_json(content)
331 pdu = event_from_pdu_json(content)
375332 logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
376333 yield self.handler.on_send_leave_request(origin, pdu)
377334 defer.returnValue((200, {}))
408365 """
409366 with (yield self._server_linearizer.queue((origin, room_id))):
410367 auth_chain = [
411 self.event_from_pdu_json(e)
368 event_from_pdu_json(e)
412369 for e in content["auth_chain"]
413370 ]
414371
583540 def __str__(self):
584541 return "<ReplicationLayer(%s)>" % self.server_name
585542
586 def event_from_pdu_json(self, pdu_json, outlier=False):
587 event = FrozenEvent(
588 pdu_json
589 )
590
591 event.internal_metadata.outlier = outlier
592
593 return event
594
595543 @defer.inlineCallbacks
596544 def exchange_third_party_invite(
597545 self,
614562 origin, room_id, event_dict
615563 )
616564 defer.returnValue(ret)
565
566
567 class FederationHandlerRegistry(object):
568 """Allows classes to register themselves as handlers for a given EDU or
569 query type for incoming federation traffic.
570 """
571 def __init__(self):
572 self.edu_handlers = {}
573 self.query_handlers = {}
574
575 def register_edu_handler(self, edu_type, handler):
576 """Sets the handler callable that will be used to handle an incoming
577 federation EDU of the given type.
578
579 Args:
580 edu_type (str): The type of the incoming EDU to register handler for
581 handler (Callable[[str, dict]]): A callable invoked on incoming EDU
582 of the given type. The arguments are the origin server name and
583 the EDU contents.
584 """
585 if edu_type in self.edu_handlers:
586 raise KeyError("Already have an EDU handler for %s" % (edu_type,))
587
588 self.edu_handlers[edu_type] = handler
589
590 def register_query_handler(self, query_type, handler):
591 """Sets the handler callable that will be used to handle an incoming
592 federation query of the given type.
593
594 Args:
595 query_type (str): Category name of the query, which should match
596 the string used by make_query.
597 handler (Callable[[dict], Deferred[dict]]): Invoked to handle
598 incoming queries of this type. The return will be yielded
599 on and the result used as the response to the query request.
600 """
601 if query_type in self.query_handlers:
602 raise KeyError(
603 "Already have a Query handler for %s" % (query_type,)
604 )
605
606 self.query_handlers[query_type] = handler
607
608 @defer.inlineCallbacks
609 def on_edu(self, edu_type, origin, content):
610 handler = self.edu_handlers.get(edu_type)
611 if not handler:
612 logger.warn("No handler registered for EDU type %s", edu_type)
613
614 try:
615 yield handler(origin, content)
616 except SynapseError as e:
617 logger.info("Failed to handle edu %r: %r", edu_type, e)
618 except Exception as e:
619 logger.exception("Failed to handle edu %r", edu_type)
620
621 def on_query(self, query_type, args):
622 handler = self.query_handlers.get(query_type)
623 if not handler:
624 logger.warn("No handler registered for query type %s", query_type)
625 raise NotFoundError("No handler for Query type '%s'" % (query_type,))
626
627 return handler(args)
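
Hypothetical registration against the new ``FederationHandlerRegistry``; ``hs.get_federation_registry()`` is the accessor used earlier in this diff, while the handler callables here are illustrative:

    # Illustrative only: a presence EDU handler and a profile query handler
    # registering themselves with the shared registry.
    registry = hs.get_federation_registry()
    registry.register_edu_handler("m.presence", on_presence_edu)
    registry.register_query_handler("profile", on_profile_query)
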
+0
-73
synapse/federation/replication.py
0 # -*- coding: utf-8 -*-
1 # Copyright 2014-2016 OpenMarket Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 """This layer is responsible for replicating with remote home servers using
16 a given transport.
17 """
18
19 from .federation_client import FederationClient
20 from .federation_server import FederationServer
21
22 from .persistence import TransactionActions
23
24 import logging
25
26
27 logger = logging.getLogger(__name__)
28
29
30 class ReplicationLayer(FederationClient, FederationServer):
31 """This layer is responsible for replicating with remote home servers over
32 the given transport. I.e., does the sending and receiving of PDUs to
33 remote home servers.
34
35 The layer communicates with the rest of the server via a registered
36 ReplicationHandler.
37
38 In more detail, the layer:
39 * Receives incoming data and processes it into transactions and pdus.
40 * Fetches any PDUs it thinks it might have missed.
41 * Keeps the current state for contexts up to date by applying the
42 suitable conflict resolution.
43 * Sends outgoing pdus wrapped in transactions.
44 * Fills out the references to previous pdus/transactions appropriately
45 for outgoing data.
46 """
47
48 def __init__(self, hs, transport_layer):
49 self.server_name = hs.hostname
50
51 self.keyring = hs.get_keyring()
52
53 self.transport_layer = transport_layer
54
55 self.federation_client = self
56
57 self.store = hs.get_datastore()
58
59 self.handler = None
60 self.edu_handlers = {}
61 self.query_handlers = {}
62
63 self._clock = hs.get_clock()
64
65 self.transaction_actions = TransactionActions(self.store)
66
67 self.hs = hs
68
69 super(ReplicationLayer, self).__init__(hs)
70
71 def __str__(self):
72 return "<ReplicationLayer(%s)>" % self.server_name
1818 from .persistence import TransactionActions
1919 from .units import Transaction, Edu
2020
21 from synapse.api.errors import HttpResponseException
22 from synapse.util import logcontext
21 from synapse.api.errors import HttpResponseException, FederationDeniedError
22 from synapse.util import logcontext, PreserveLoggingContext
2323 from synapse.util.async import run_on_reactor
2424 from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
2525 from synapse.util.metrics import measure_func
4040 sent_edus_counter = client_metrics.register_counter("sent_edus")
4141
4242 sent_transactions_counter = client_metrics.register_counter("sent_transactions")
43
44 events_processed_counter = client_metrics.register_counter("events_processed")
4345
4446
4547 class TransactionQueue(object):
145147 else:
146148 return not destination.startswith("localhost")
147149
148 @defer.inlineCallbacks
149150 def notify_new_events(self, current_id):
150151 """This gets called when we have some new events we might want to
151152 send out to other servers.
155156 if self._is_processing:
156157 return
157158
159 # fire off a processing loop in the background. It's likely it will
160 # outlast the current request, so run it in the sentinel logcontext.
161 with PreserveLoggingContext():
162 self._process_event_queue_loop()
163
164 @defer.inlineCallbacks
165 def _process_event_queue_loop(self):
158166 try:
159167 self._is_processing = True
160168 while True:
197205 logger.debug("Sending %s to %r", event, destinations)
198206
199207 self._send_pdu(event, destinations)
208
209 events_processed_counter.inc_by(len(events))
200210
201211 yield self.store.update_federation_out_pos(
202212 "events", next_token
479489 (e.retry_last_ts + e.retry_interval) / 1000.0
480490 ),
481491 )
492 except FederationDeniedError as e:
493 logger.info(e)
482494 except Exception as e:
483495 logger.warn(
484496 "TX [%s] Failed to send transaction: %s",
211211
212212 Fails with ``NotRetryingDestination`` if we are not yet ready
213213 to retry this server.
214
215 Fails with ``FederationDeniedError`` if the remote destination
216 is not in our federation whitelist
214217 """
215218 valid_memberships = {Membership.JOIN, Membership.LEAVE}
216219 if membership not in valid_memberships:
485488 )
486489
487490 @log_function
491 def update_group_profile(self, destination, group_id, requester_user_id, content):
492 """Update a remote group profile
493
494 Args:
495 destination (str)
496 group_id (str)
497 requester_user_id (str)
498 content (dict): The new profile of the group
499 """
500 path = PREFIX + "/groups/%s/profile" % (group_id,)
501
502 return self.client.post_json(
503 destination=destination,
504 path=path,
505 args={"requester_user_id": requester_user_id},
506 data=content,
507 ignore_backoff=True,
508 )
509
510 @log_function
488511 def get_group_summary(self, destination, group_id, requester_user_id):
489512 """Get a group summary
490513 """
515538 """Add a room to a group
516539 """
517540 path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,)
541
542 return self.client.post_json(
543 destination=destination,
544 path=path,
545 args={"requester_user_id": requester_user_id},
546 data=content,
547 ignore_backoff=True,
548 )
549
550 def update_room_in_group(self, destination, group_id, requester_user_id, room_id,
551 config_key, content):
552 """Update room in group
553 """
554 path = PREFIX + "/groups/%s/room/%s/config/%s" % (group_id, room_id, config_key,)
518555
519556 return self.client.post_json(
520557 destination=destination,
1515 from twisted.internet import defer
1616
1717 from synapse.api.urls import FEDERATION_PREFIX as PREFIX
18 from synapse.api.errors import Codes, SynapseError
18 from synapse.api.errors import Codes, SynapseError, FederationDeniedError
1919 from synapse.http.server import JsonResource
2020 from synapse.http.servlet import (
2121 parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
8080 self.keyring = hs.get_keyring()
8181 self.server_name = hs.hostname
8282 self.store = hs.get_datastore()
83 self.federation_domain_whitelist = hs.config.federation_domain_whitelist
8384
8485 # A method just so we can pass 'self' as the authenticator to the Servlets
8586 @defer.inlineCallbacks
9192 "signatures": {},
9293 }
9394
95 if (
96 self.federation_domain_whitelist is not None and
97 self.server_name not in self.federation_domain_whitelist
98 ):
99 raise FederationDeniedError(self.server_name)
100
94101 if content is not None:
95102 json_request["content"] = content
96103
111118 key = strip_quotes(param_dict["key"])
112119 sig = strip_quotes(param_dict["sig"])
113120 return (origin, key, sig)
114 except:
121 except Exception:
115122 raise AuthenticationError(
116123 400, "Malformed Authorization header", Codes.UNAUTHORIZED
117124 )
176183 if self.REQUIRE_AUTH:
177184 logger.exception("authenticate_request failed")
178185 raise
179 except:
186 except Exception:
180187 logger.exception("authenticate_request failed")
181188 raise
182189
269276 code, response = yield self.handler.on_incoming_transaction(
270277 transaction_data
271278 )
272 except:
279 except Exception:
273280 logger.exception("on_incoming_transaction failed")
274281 raise
275282
609616
610617
611618 class FederationGroupsProfileServlet(BaseFederationServlet):
612 """Get the basic profile of a group on behalf of a user
619 """Get/set the basic profile of a group on behalf of a user
613620 """
614621 PATH = "/groups/(?P<group_id>[^/]*)/profile$"
615622
625632
626633 defer.returnValue((200, new_content))
627634
635 @defer.inlineCallbacks
636 def on_POST(self, origin, content, query, group_id):
637 requester_user_id = parse_string_from_args(query, "requester_user_id")
638 if get_domain_from_id(requester_user_id) != origin:
639 raise SynapseError(403, "requester_user_id doesn't match origin")
640
641 new_content = yield self.handler.update_group_profile(
642 group_id, requester_user_id, content
643 )
644
645 defer.returnValue((200, new_content))
646
628647
629648 class FederationGroupsSummaryServlet(BaseFederationServlet):
630649 PATH = "/groups/(?P<group_id>[^/]*)/summary$"
641660
642661 defer.returnValue((200, new_content))
643662
644 @defer.inlineCallbacks
645 def on_POST(self, origin, content, query, group_id):
646 requester_user_id = parse_string_from_args(query, "requester_user_id")
647 if get_domain_from_id(requester_user_id) != origin:
648 raise SynapseError(403, "requester_user_id doesn't match origin")
649
650 new_content = yield self.handler.update_group_profile(
651 group_id, requester_user_id, content
652 )
653
654 defer.returnValue((200, new_content))
655
656663
657664 class FederationGroupsRoomsServlet(BaseFederationServlet):
658665 """Get the rooms in a group on behalf of a user
675682 class FederationGroupsAddRoomsServlet(BaseFederationServlet):
676683 """Add/remove room from group
677684 """
678 PATH = "/groups/(?P<group_id>[^/]*)/room/(?<room_id>)$"
685 PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)$"
679686
680687 @defer.inlineCallbacks
681688 def on_POST(self, origin, content, query, group_id, room_id):
700707 )
701708
702709 defer.returnValue((200, new_content))
710
711
712 class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
713 """Update room config in group
714 """
715 PATH = (
716 "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
717 "/config/(?P<config_key>[^/]*)$"
718 )
719
720 @defer.inlineCallbacks
721 def on_POST(self, origin, content, query, group_id, room_id, config_key):
722 requester_user_id = parse_string_from_args(query, "requester_user_id")
723 if get_domain_from_id(requester_user_id) != origin:
724 raise SynapseError(403, "requester_user_id doesn't match origin")
725
726 result = yield self.groups_handler.update_room_in_group(
727 group_id, requester_user_id, room_id, config_key, content,
728 )
729
730 defer.returnValue((200, result))
703731
704732
705733 class FederationGroupsUsersServlet(BaseFederationServlet):
11411169 FederationGroupsRolesServlet,
11421170 FederationGroupsRoleServlet,
11431171 FederationGroupsSummaryUsersServlet,
1172 FederationGroupsAddRoomsServlet,
1173 FederationGroupsAddRoomsConfigServlet,
11441174 )
11451175
11461176
11591189 def register_servlets(hs, resource, authenticator, ratelimiter):
11601190 for servletclass in FEDERATION_SERVLET_CLASSES:
11611191 servletclass(
1162 handler=hs.get_replication_layer(),
1192 handler=hs.get_federation_server(),
11631193 authenticator=authenticator,
11641194 ratelimiter=ratelimiter,
11651195 server_name=hs.hostname,
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 """Attestations ensure that users and groups can't lie about their memberships.
16
17 When a user joins a group the HS and GS swap attestations, which allow them
18 both to independently prove to third parties their membership. These
19 attestations have a validity period so need to be periodically renewed.
20
21 If a user leaves (or gets kicked out of) a group, either side can still use
22 their attestation to "prove" their membership, until the attestation expires.
23 Therefore attestations shouldn't be relied on to prove membership in important
24 cases, but can be for less important situations, e.g. showing a user's membership
25 of groups on their profile, showing flairs, etc.
26
27 An attestation is a signed blob of JSON that looks like:
28
29 {
30 "user_id": "@foo:a.example.com",
31 "group_id": "+bar:b.example.com",
32 "valid_until_ms": 1507994728530,
33 "signatures":{"matrix.org":{"ed25519:auto":"..."}}
34 }
35 """
36
37 import logging
38 import random
39
1540 from twisted.internet import defer
1641
1742 from synapse.api.errors import SynapseError
2146 from signedjson.sign import sign_json
2247
2348
49 logger = logging.getLogger(__name__)
50
51
2452 # Default validity duration for new attestations we create
2553 DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000
54
55 # We add some jitter to the validity duration of attestations so that if we
56 # add lots of users at once we don't need to renew them all at once.
57 # The jitter is a multiplier picked randomly between the first and second number
58 DEFAULT_ATTESTATION_JITTER = (0.9, 1.3)
2659
2760 # Start trying to update our attestations when they come this close to expiring
2861 UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
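In concrete terms, with a 3-day base length and a (0.9, 1.3) jitter multiplier, each freshly created attestation ends up valid for roughly 0.9 * 3 = 2.7 to 1.3 * 3 = 3.9 days, and renewal is attempted once it comes within one day of expiring.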
72105 """Create an attestation for the group_id and user_id with default
73106 validity length.
74107 """
108 validity_period = DEFAULT_ATTESTATION_LENGTH_MS
109 validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER)
110 valid_until_ms = int(self.clock.time_msec() + validity_period)
111
75112 return sign_json({
76113 "group_id": group_id,
77114 "user_id": user_id,
78 "valid_until_ms": self.clock.time_msec() + DEFAULT_ATTESTATION_LENGTH_MS,
115 "valid_until_ms": valid_until_ms,
79116 }, self.server_name, self.signing_key)
80117
81118
127164
128165 @defer.inlineCallbacks
129166 def _renew_attestation(group_id, user_id):
130 attestation = self.attestations.create_attestation(group_id, user_id)
131
132 if self.is_mine_id(group_id):
167 if not self.is_mine_id(group_id):
168 destination = get_domain_from_id(group_id)
169 elif not self.is_mine_id(user_id):
133170 destination = get_domain_from_id(user_id)
134171 else:
135 destination = get_domain_from_id(group_id)
172 logger.warn(
173 "Incorrectly trying to do attestations for user: %r in %r",
174 user_id, group_id,
175 )
176 yield self.store.remove_attestation_renewal(group_id, user_id)
177 return
178
179 attestation = self.attestations.create_attestation(group_id, user_id)
136180
137181 yield self.transport_client.renew_group_attestation(
138182 destination, group_id, user_id,
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 import logging
16
17 from synapse.api.errors import SynapseError
18 from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
1519 from twisted.internet import defer
16
17 from synapse.api.errors import SynapseError
18 from synapse.types import UserID, get_domain_from_id, RoomID, GroupID
19
20
21 import logging
22 import urllib
2320
2421 logger = logging.getLogger(__name__)
2522
5148 hs.get_groups_attestation_renewer()
5249
5350 @defer.inlineCallbacks
54 def check_group_is_ours(self, group_id, and_exists=False, and_is_admin=None):
51 def check_group_is_ours(self, group_id, requester_user_id,
52 and_exists=False, and_is_admin=None):
5553 """Check that the group is ours, and optionally if it exists.
5654
5755 If group does exist then return group.
6967 if and_exists and not group:
7068 raise SynapseError(404, "Unknown group")
7169
70 is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
71 if group and not is_user_in_group and not group["is_public"]:
72 raise SynapseError(404, "Unknown group")
73
7274 if and_is_admin:
7375 is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin)
7476 if not is_admin:
8688
8789 A user/room may appear in multiple roles/categories.
8890 """
89 yield self.check_group_is_ours(group_id, and_exists=True)
91 yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
9092
9193 is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
9294
155157 })
156158
157159 @defer.inlineCallbacks
158 def update_group_summary_room(self, group_id, user_id, room_id, category_id, content):
160 def update_group_summary_room(self, group_id, requester_user_id,
161 room_id, category_id, content):
159162 """Add/update a room to the group summary
160163 """
161 yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
164 yield self.check_group_is_ours(
165 group_id,
166 requester_user_id,
167 and_exists=True,
168 and_is_admin=requester_user_id,
169 )
162170
163171 RoomID.from_string(room_id) # Ensure valid room id
164172
177185 defer.returnValue({})
178186
179187 @defer.inlineCallbacks
180 def delete_group_summary_room(self, group_id, user_id, room_id, category_id):
188 def delete_group_summary_room(self, group_id, requester_user_id,
189 room_id, category_id):
181190 """Remove a room from the summary
182191 """
183 yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
192 yield self.check_group_is_ours(
193 group_id,
194 requester_user_id,
195 and_exists=True,
196 and_is_admin=requester_user_id,
197 )
184198
185199 yield self.store.remove_room_from_summary(
186200 group_id=group_id,
191205 defer.returnValue({})
192206
193207 @defer.inlineCallbacks
194 def get_group_categories(self, group_id, user_id):
208 def get_group_categories(self, group_id, requester_user_id):
195209 """Get all categories in a group (as seen by user)
196210 """
197 yield self.check_group_is_ours(group_id, and_exists=True)
211 yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
198212
199213 categories = yield self.store.get_group_categories(
200214 group_id=group_id,
202216 defer.returnValue({"categories": categories})
203217
204218 @defer.inlineCallbacks
205 def get_group_category(self, group_id, user_id, category_id):
219 def get_group_category(self, group_id, requester_user_id, category_id):
206220 """Get a specific category in a group (as seen by user)
207221 """
208 yield self.check_group_is_ours(group_id, and_exists=True)
222 yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
209223
210224 res = yield self.store.get_group_category(
211225 group_id=group_id,
215229 defer.returnValue(res)
216230
217231 @defer.inlineCallbacks
218 def update_group_category(self, group_id, user_id, category_id, content):
232 def update_group_category(self, group_id, requester_user_id, category_id, content):
219233 """Add/Update a group category
220234 """
221 yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
235 yield self.check_group_is_ours(
236 group_id,
237 requester_user_id,
238 and_exists=True,
239 and_is_admin=requester_user_id,
240 )
222241
223242 is_public = _parse_visibility_from_contents(content)
224243 profile = content.get("profile")
233252 defer.returnValue({})
234253
235254 @defer.inlineCallbacks
236 def delete_group_category(self, group_id, user_id, category_id):
255 def delete_group_category(self, group_id, requester_user_id, category_id):
237256 """Delete a group category
238257 """
239 yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
258 yield self.check_group_is_ours(
259 group_id,
260 requester_user_id,
261 and_exists=True,
262 and_is_admin=requester_user_id
263 )
240264
241265 yield self.store.remove_group_category(
242266 group_id=group_id,
246270 defer.returnValue({})
247271
248272 @defer.inlineCallbacks
249 def get_group_roles(self, group_id, user_id):
273 def get_group_roles(self, group_id, requester_user_id):
250274 """Get all roles in a group (as seen by user)
251275 """
252 yield self.check_group_is_ours(group_id, and_exists=True)
276 yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
253277
254278 roles = yield self.store.get_group_roles(
255279 group_id=group_id,
257281 defer.returnValue({"roles": roles})
258282
259283 @defer.inlineCallbacks
260 def get_group_role(self, group_id, user_id, role_id):
284 def get_group_role(self, group_id, requester_user_id, role_id):
261285 """Get a specific role in a group (as seen by user)
262286 """
263 yield self.check_group_is_ours(group_id, and_exists=True)
287 yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
264288
265289 res = yield self.store.get_group_role(
266290 group_id=group_id,
269293 defer.returnValue(res)
270294
271295 @defer.inlineCallbacks
272 def update_group_role(self, group_id, user_id, role_id, content):
296 def update_group_role(self, group_id, requester_user_id, role_id, content):
273297 """Add/update a role in a group
274298 """
275 yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
299 yield self.check_group_is_ours(
300 group_id,
301 requester_user_id,
302 and_exists=True,
303 and_is_admin=requester_user_id,
304 )
276305
277306 is_public = _parse_visibility_from_contents(content)
278307
288317 defer.returnValue({})
289318
290319 @defer.inlineCallbacks
291 def delete_group_role(self, group_id, user_id, role_id):
320 def delete_group_role(self, group_id, requester_user_id, role_id):
292321 """Remove role from group
293322 """
294 yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
323 yield self.check_group_is_ours(
324 group_id,
325 requester_user_id,
326 and_exists=True,
327 and_is_admin=requester_user_id,
328 )
295329
296330 yield self.store.remove_group_role(
297331 group_id=group_id,
306340 """Add/update a user's entry in the group summary
307341 """
308342 yield self.check_group_is_ours(
309 group_id, and_exists=True, and_is_admin=requester_user_id,
343 group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
310344 )
311345
312346 order = content.get("order", None)
328362 """Remove a user from the group summary
329363 """
330364 yield self.check_group_is_ours(
331 group_id, and_exists=True, and_is_admin=requester_user_id,
365 group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
332366 )
333367
334368 yield self.store.remove_user_from_summary(
344378 """Get the group profile as seen by requester_user_id
345379 """
346380
347 yield self.check_group_is_ours(group_id)
381 yield self.check_group_is_ours(group_id, requester_user_id)
348382
349383 group_description = yield self.store.get_group(group_id)
350384
358392 """Update the group profile
359393 """
360394 yield self.check_group_is_ours(
361 group_id, and_exists=True, and_is_admin=requester_user_id,
395 group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
362396 )
363397
364398 profile = {}
379413 The ordering is arbitrary at the moment
380414 """
381415
382 yield self.check_group_is_ours(group_id, and_exists=True)
416 yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
383417
384418 is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
385419
391425 for user_result in user_results:
392426 g_user_id = user_result["user_id"]
393427 is_public = user_result["is_public"]
428 is_privileged = user_result["is_admin"]
394429
395430 entry = {"user_id": g_user_id}
396431
397432 profile = yield self.profile_handler.get_profile_from_cache(g_user_id)
398433 entry.update(profile)
399434
400 if not is_public:
401 entry["is_public"] = False
435 entry["is_public"] = bool(is_public)
436 entry["is_privileged"] = bool(is_privileged)
402437
403438 if not self.is_mine_id(g_user_id):
404439 attestation = yield self.store.get_remote_attestation(group_id, g_user_id)
427462 The ordering is arbitrary at the moment
428463 """
429464
430 yield self.check_group_is_ours(group_id, and_exists=True)
465 yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
431466
432467 is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
433468
461496 This returns rooms in order of decreasing number of joined users
462497 """
463498
464 yield self.check_group_is_ours(group_id, and_exists=True)
499 yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
465500
466501 is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
467502
472507 chunk = []
473508 for room_result in room_results:
474509 room_id = room_result["room_id"]
475 is_public = room_result["is_public"]
476510
477511 joined_users = yield self.store.get_users_in_room(room_id)
478512 entry = yield self.room_list_handler.generate_room_entry(
483517 if not entry:
484518 continue
485519
486 if not is_public:
487 entry["is_public"] = False
520 entry["is_public"] = bool(room_result["is_public"])
488521
489522 chunk.append(entry)
490523
502535 RoomID.from_string(room_id) # Ensure valid room id
503536
504537 yield self.check_group_is_ours(
505 group_id, and_exists=True, and_is_admin=requester_user_id
538 group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
506539 )
507540
508541 is_public = _parse_visibility_from_contents(content)
509542
510543 yield self.store.add_room_to_group(group_id, room_id, is_public=is_public)
544
545 defer.returnValue({})
546
547 @defer.inlineCallbacks
548 def update_room_in_group(self, group_id, requester_user_id, room_id, config_key,
549 content):
550 """Update room in group
551 """
552 RoomID.from_string(room_id) # Ensure valid room id
553
554 yield self.check_group_is_ours(
555 group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
556 )
557
558 if config_key == "m.visibility":
559 is_public = _parse_visibility_dict(content)
560
561 yield self.store.update_room_in_group_visibility(
562 group_id, room_id,
563 is_public=is_public,
564 )
565 else:
566 raise SynapseError(400, "Unknown config option")
511567
512568 defer.returnValue({})
513569
516572 """Remove room from group
517573 """
518574 yield self.check_group_is_ours(
519 group_id, and_exists=True, and_is_admin=requester_user_id
575 group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
520576 )
521577
522578 yield self.store.remove_room_from_group(group_id, room_id)
529585 """
530586
531587 group = yield self.check_group_is_ours(
532 group_id, and_exists=True, and_is_admin=requester_user_id
588 group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
533589 )
534590
535591 # TODO: Check if user knocked
598654 raise SynapseError(502, "Unknown state returned by HS")
599655
600656 @defer.inlineCallbacks
601 def accept_invite(self, group_id, user_id, content):
657 def accept_invite(self, group_id, requester_user_id, content):
602658 """User tries to accept an invite to the group.
603659
604660 This is different from them asking to join, and so should error if no
605661 invite exists (and they're not a member of the group)
606662 """
607663
608 yield self.check_group_is_ours(group_id, and_exists=True)
609
610 if not self.store.is_user_invited_to_local_group(group_id, user_id):
664 yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
665
666 is_invited = yield self.store.is_user_invited_to_local_group(
667 group_id, requester_user_id,
668 )
669 if not is_invited:
611670 raise SynapseError(403, "User not invited to group")
612671
613 if not self.hs.is_mine_id(user_id):
672 if not self.hs.is_mine_id(requester_user_id):
673 local_attestation = self.attestations.create_attestation(
674 group_id, requester_user_id,
675 )
614676 remote_attestation = content["attestation"]
615677
616678 yield self.attestations.verify_attestation(
617679 remote_attestation,
618 user_id=user_id,
680 user_id=requester_user_id,
619681 group_id=group_id,
620682 )
621683 else:
684 local_attestation = None
622685 remote_attestation = None
623686
624 local_attestation = self.attestations.create_attestation(group_id, user_id)
625
626687 is_public = _parse_visibility_from_contents(content)
627688
628689 yield self.store.add_user_to_group(
629 group_id, user_id,
690 group_id, requester_user_id,
630691 is_admin=False,
631692 is_public=is_public,
632693 local_attestation=local_attestation,
639700 })
640701
641702 @defer.inlineCallbacks
642 def knock(self, group_id, user_id, content):
703 def knock(self, group_id, requester_user_id, content):
643704 """A user requests becoming a member of the group
644705 """
645 yield self.check_group_is_ours(group_id, and_exists=True)
706 yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
646707
647708 raise NotImplementedError()
648709
649710 @defer.inlineCallbacks
650 def accept_knock(self, group_id, user_id, content):
711 def accept_knock(self, group_id, requester_user_id, content):
651712 """Accept a user's knock to the room.
652713
653714 Errors if the user hasn't knocked, rather than inviting them.
654715 """
655716
656 yield self.check_group_is_ours(group_id, and_exists=True)
717 yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
657718
658719 raise NotImplementedError()
659720
660721 @defer.inlineCallbacks
661722 def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
662 """Remove a user from the group; either a user is leaving or and admin
663 kicked htem.
664 """
665
666 yield self.check_group_is_ours(group_id, and_exists=True)
723 """Remove a user from the group; either a user is leaving or an admin
724 kicked them.
725 """
726
727 yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
667728
668729 is_kick = False
669730 if requester_user_id != user_id:
694755 defer.returnValue({})
695756
696757 @defer.inlineCallbacks
697 def create_group(self, group_id, user_id, content):
698 group = yield self.check_group_is_ours(group_id)
699
700 _validate_group_id(group_id)
758 def create_group(self, group_id, requester_user_id, content):
759 group = yield self.check_group_is_ours(group_id, requester_user_id)
701760
702761 logger.info("Attempting to create group with ID: %r", group_id)
762
763 # parsing the id into a GroupID validates it.
764 group_id_obj = GroupID.from_string(group_id)
765
703766 if group:
704767 raise SynapseError(400, "Group already exists")
705768
706 is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id))
769 is_admin = yield self.auth.is_server_admin(UserID.from_string(requester_user_id))
707770 if not is_admin:
708771 if not self.hs.config.enable_group_creation:
709772 raise SynapseError(
710 403, "Only server admin can create group on this server",
773 403, "Only a server admin can create groups on this server",
711774 )
712 localpart = GroupID.from_string(group_id).localpart
775 localpart = group_id_obj.localpart
713776 if not localpart.startswith(self.hs.config.group_creation_prefix):
714777 raise SynapseError(
715778 400,
727790
728791 yield self.store.create_group(
729792 group_id,
730 user_id,
793 requester_user_id,
731794 name=name,
732795 avatar_url=avatar_url,
733796 short_description=short_description,
734797 long_description=long_description,
735798 )
736799
737 if not self.hs.is_mine_id(user_id):
800 if not self.hs.is_mine_id(requester_user_id):
738801 remote_attestation = content["attestation"]
739802
740803 yield self.attestations.verify_attestation(
741804 remote_attestation,
742 user_id=user_id,
805 user_id=requester_user_id,
743806 group_id=group_id,
744807 )
745808
746 local_attestation = self.attestations.create_attestation(group_id, user_id)
809 local_attestation = self.attestations.create_attestation(
810 group_id,
811 requester_user_id,
812 )
747813 else:
748814 local_attestation = None
749815 remote_attestation = None
750816
751817 yield self.store.add_user_to_group(
752 group_id, user_id,
818 group_id, requester_user_id,
753819 is_admin=True,
754820 is_public=True, # TODO
755821 local_attestation=local_attestation,
756822 remote_attestation=remote_attestation,
757823 )
758824
759 if not self.hs.is_mine_id(user_id):
825 if not self.hs.is_mine_id(requester_user_id):
760826 yield self.store.add_remote_profile_cache(
761 user_id,
827 requester_user_id,
762828 displayname=user_profile.get("displayname"),
763829 avatar_url=user_profile.get("avatar_url"),
764830 )
773839 public or not
774840 """
775841
776 visibility = content.get("visibility")
842 visibility = content.get("m.visibility")
777843 if visibility:
778 vis_type = visibility["type"]
779 if vis_type not in ("public", "private"):
780 raise SynapseError(
781 400, "Synapse only supports 'public'/'private' visibility"
782 )
783 is_public = vis_type == "public"
844 return _parse_visibility_dict(visibility)
784845 else:
785846 is_public = True
786847
787848 return is_public
788849
789850
790 def _validate_group_id(group_id):
791 """Validates the group ID is valid for creation on this home server
851 def _parse_visibility_dict(visibility):
852 """Given a dict for the "m.visibility" config, return whether the entity should
853 be public or not
792854 """
793 localpart = GroupID.from_string(group_id).localpart
794
795 if localpart.lower() != localpart:
796 raise SynapseError(400, "Group ID must be lower case")
797
798 if urllib.quote(localpart.encode('utf-8')) != localpart:
855 vis_type = visibility.get("type")
856 if not vis_type:
857 return True
858
859 if vis_type not in ("public", "private"):
799860 raise SynapseError(
800 400,
801 "Group ID can only contain characters a-z, 0-9, or '_-./'",
802 )
861 400, "Synapse only supports 'public'/'private' visibility"
862 )
863 return vis_type == "public"
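Assuming the two helpers above, a few representative inputs would be interpreted as follows (illustrative calls, not part of the diff):

# default when the request carries no visibility at all: public
_parse_visibility_from_contents({})                                      # -> True
# an explicit m.visibility dict is delegated to _parse_visibility_dict
_parse_visibility_from_contents({"m.visibility": {"type": "private"}})   # -> False
_parse_visibility_dict({"type": "public"})                               # -> True
_parse_visibility_dict({"type": "internal"})                             # raises SynapseError(400, ...)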
1616 from .room import (
1717 RoomCreationHandler, RoomContextHandler,
1818 )
19 from .room_member import RoomMemberHandler
2019 from .message import MessageHandler
2120 from .federation import FederationHandler
2221 from .directory import DirectoryHandler
4847 self.registration_handler = RegistrationHandler(hs)
4948 self.message_handler = MessageHandler(hs)
5049 self.room_creation_handler = RoomCreationHandler(hs)
51 self.room_member_handler = RoomMemberHandler(hs)
5250 self.federation_handler = FederationHandler(hs)
5351 self.directory_handler = DirectoryHandler(hs)
5452 self.admin_handler = AdminHandler(hs)
157157 # homeserver.
158158 requester = synapse.types.create_requester(
159159 target_user, is_guest=True)
160 handler = self.hs.get_handlers().room_member_handler
160 handler = self.hs.get_room_member_handler()
161161 yield handler.update_membership(
162162 requester,
163163 target_user,
1414
1515 from twisted.internet import defer
1616
17 import synapse
1718 from synapse.api.constants import EventTypes
1819 from synapse.util.metrics import Measure
19 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
20 from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
2021
2122 import logging
2223
2324 logger = logging.getLogger(__name__)
25
26 metrics = synapse.metrics.get_metrics_for(__name__)
27
28 events_processed_counter = metrics.register_counter("events_processed")
2429
2530
2631 def log_failure(failure):
6974 with Measure(self.clock, "notify_interested_services"):
7075 self.is_processing = True
7176 try:
72 upper_bound = self.current_max
7377 limit = 100
7478 while True:
7579 upper_bound, events = yield self.store.get_new_events_for_appservice(
76 upper_bound, limit
80 self.current_max, limit
7781 )
7882
7983 if not events:
103107 service, event
104108 )
105109
110 events_processed_counter.inc_by(len(events))
111
106112 yield self.store.set_appservice_last_pos(upper_bound)
107
108 if len(events) < limit:
109 break
110113 finally:
111114 self.is_processing = False
112115
162165 def query_3pe(self, kind, protocol, fields):
163166 services = yield self._get_services_for_3pn(protocol)
164167
165 results = yield preserve_context_over_deferred(defer.DeferredList([
168 results = yield make_deferred_yieldable(defer.DeferredList([
166169 preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
167170 for service in services
168171 ], consumeErrors=True))
1212 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
15
16 from twisted.internet import defer
15 from twisted.internet import defer, threads
1716
1817 from ._base import BaseHandler
1918 from synapse.api.constants import LoginType
19 from synapse.api.errors import (
20 AuthError, Codes, InteractiveAuthIncompleteError, LoginError, StoreError,
21 SynapseError,
22 )
23 from synapse.module_api import ModuleApi
2024 from synapse.types import UserID
21 from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
2225 from synapse.util.async import run_on_reactor
2326 from synapse.util.caches.expiringcache import ExpiringCache
27 from synapse.util.logcontext import make_deferred_yieldable
2428
2529 from twisted.web.client import PartialDownloadError
2630
4549 """
4650 super(AuthHandler, self).__init__(hs)
4751 self.checkers = {
48 LoginType.PASSWORD: self._check_password_auth,
4952 LoginType.RECAPTCHA: self._check_recaptcha,
5053 LoginType.EMAIL_IDENTITY: self._check_email_identity,
5154 LoginType.MSISDN: self._check_msisdn,
6265 reset_expiry_on_get=True,
6366 )
6467
65 account_handler = _AccountHandler(
66 hs, check_user_exists=self.check_user_exists
67 )
68
68 account_handler = ModuleApi(hs, self)
6969 self.password_providers = [
7070 module(config=config, account_handler=account_handler)
7171 for module, config in hs.config.password_providers
7474 logger.info("Extra password_providers: %r", self.password_providers)
7575
7676 self.hs = hs # FIXME better possibility to access registrationHandler later?
77 self.device_handler = hs.get_device_handler()
7877 self.macaroon_gen = hs.get_macaroon_generator()
78 self._password_enabled = hs.config.password_enabled
79
80 # we keep this as a list despite the O(N^2) implication so that we can
81 # keep PASSWORD first and avoid confusing clients which pick the first
82 # type in the list. (NB that the spec doesn't require us to do so and
83 # clients which favour types that they don't understand over those that
84 # they do are technically broken)
85 login_types = []
86 if self._password_enabled:
87 login_types.append(LoginType.PASSWORD)
88 for provider in self.password_providers:
89 if hasattr(provider, "get_supported_login_types"):
90 for t in provider.get_supported_login_types().keys():
91 if t not in login_types:
92 login_types.append(t)
93 self._supported_login_types = login_types
94
95 @defer.inlineCallbacks
96 def validate_user_via_ui_auth(self, requester, request_body, clientip):
97 """
98 Checks that the user is who they claim to be, via a UI auth.
99
100 This is used for things like device deletion and password reset where
101 the user already has a valid access token, but we want to double-check
102 that it isn't stolen by re-authenticating them.
103
104 Args:
105 requester (Requester): The user, as given by the access token
106
107 request_body (dict): The body of the request sent by the client
108
109 clientip (str): The IP address of the client.
110
111 Returns:
112 defer.Deferred[dict]: the parameters for this request (which may
113 have been given only in a previous call).
114
115 Raises:
116 InteractiveAuthIncompleteError if the client has not yet completed
117 any of the permitted login flows
118
119 AuthError if the client has completed a login flow, and it gives
120 a different user to `requester`
121 """
122
123 # build a list of supported flows
124 flows = [
125 [login_type] for login_type in self._supported_login_types
126 ]
127
128 result, params, _ = yield self.check_auth(
129 flows, request_body, clientip,
130 )
131
132 # find the completed login type
133 for login_type in self._supported_login_types:
134 if login_type not in result:
135 continue
136
137 user_id = result[login_type]
138 break
139 else:
140 # this can't happen
141 raise Exception(
142 "check_auth returned True but no successful login type",
143 )
144
145 # check that the UI auth matched the access token
146 if user_id != requester.user.to_string():
147 raise AuthError(403, "Invalid auth")
148
149 defer.returnValue(params)
79150
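A hedged sketch of how a REST servlet that already holds an access token might call the new validate_user_via_ui_auth before a sensitive operation; the servlet shape and attribute names here are illustrative rather than taken from this diff:

@defer.inlineCallbacks
def on_POST(self, request):
    # Re-check the user's credentials even though they already hold a token,
    # e.g. before deleting one of their devices.
    requester = yield self.auth.get_user_by_req(request)
    body = parse_json_object_from_request(request)

    yield self.auth_handler.validate_user_via_ui_auth(
        requester, body, self.hs.get_ip_from_request(request),
    )

    # ...perform the sensitive operation for requester.user here...
    defer.returnValue((200, {}))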
80151 @defer.inlineCallbacks
81152 def check_auth(self, flows, clientdict, clientip):
82153 """
83154 Takes a dictionary sent by the client in the login / registration
84 protocol and handles the login flow.
155 protocol and handles the User-Interactive Auth flow.
85156
86157 As a side effect, this function fills in the 'creds' key on the user's
87158 session with a map, which maps each auth-type (str) to the relevant
88159 identity authenticated by that auth-type (mostly str, but for captcha, bool).
89160
161 If no auth flows have been completed successfully, raises an
162 InteractiveAuthIncompleteError. To handle this, you can use
163 synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
164 decorator.
165
90166 Args:
91167 flows (list): A list of login flows. Each flow is an ordered list of
92168 strings representing auth-types. At least one full
93169 flow must be completed in order for auth to be successful.
170
94171 clientdict: The dictionary from the client root level, not the
95172 'auth' key: this method prompts for auth if none is sent.
173
96174 clientip (str): The IP address of the client.
97 Returns:
98 A tuple of (authed, dict, dict, session_id) where authed is true if
99 the client has successfully completed an auth flow. If it is true
100 the first dict contains the authenticated credentials of each stage.
101
102 If authed is false, the first dictionary is the server response to
103 the login request and should be passed back to the client.
104
105 In either case, the second dict contains the parameters for this
106 request (which may have been given only in a previous call).
107
108 session_id is the ID of this session, either passed in by the client
109 or assigned by the call to check_auth
175
176 Returns:
177 defer.Deferred[dict, dict, str]: a deferred tuple of
178 (creds, params, session_id).
179
180 'creds' contains the authenticated credentials of each stage.
181
182 'params' contains the parameters for this request (which may
183 have been given only in a previous call).
184
185 'session_id' is the ID of this session, either passed in by the
186 client or assigned by this call
187
188 Raises:
189 InteractiveAuthIncompleteError if the client has not yet completed
190 all the stages in any of the permitted flows.
110191 """
111192
112193 authdict = None
134215 clientdict = session['clientdict']
135216
136217 if not authdict:
137 defer.returnValue(
138 (
139 False, self._auth_dict_for_flows(flows, session),
140 clientdict, session['id']
141 )
218 raise InteractiveAuthIncompleteError(
219 self._auth_dict_for_flows(flows, session),
142220 )
143221
144222 if 'creds' not in session:
149227 errordict = {}
150228 if 'type' in authdict:
151229 login_type = authdict['type']
152 if login_type not in self.checkers:
153 raise LoginError(400, "", Codes.UNRECOGNIZED)
154230 try:
155 result = yield self.checkers[login_type](authdict, clientip)
231 result = yield self._check_auth_dict(authdict, clientip)
156232 if result:
157233 creds[login_type] = result
158234 self._save_session(session)
159 except LoginError, e:
235 except LoginError as e:
160236 if login_type == LoginType.EMAIL_IDENTITY:
161237 # riot used to have a bug where it would request a new
162238 # validation token (thus sending a new email) each time it
165241 #
166242 # Grandfather in the old behaviour for now to avoid
167243 # breaking old riot deployments.
168 raise e
244 raise
169245
170246 # this step failed. Merge the error dict into the response
171247 # so that the client can have another go.
182258 "Auth completed with creds: %r. Client dict has keys: %r",
183259 creds, clientdict.keys()
184260 )
185 defer.returnValue((True, creds, clientdict, session['id']))
261 defer.returnValue((creds, clientdict, session['id']))
186262
187263 ret = self._auth_dict_for_flows(flows, session)
188264 ret['completed'] = creds.keys()
189265 ret.update(errordict)
190 defer.returnValue((False, ret, clientdict, session['id']))
266 raise InteractiveAuthIncompleteError(
267 ret,
268 )
191269
192270 @defer.inlineCallbacks
193271 def add_oob_auth(self, stagetype, authdict, clientip):
259337 sess = self._get_session_info(session_id)
260338 return sess.setdefault('serverdict', {}).get(key, default)
261339
262 def _check_password_auth(self, authdict, _):
263 if "user" not in authdict or "password" not in authdict:
264 raise LoginError(400, "", Codes.MISSING_PARAM)
265
266 user_id = authdict["user"]
267 password = authdict["password"]
268 if not user_id.startswith('@'):
269 user_id = UserID.create(user_id, self.hs.hostname).to_string()
270
271 return self._check_password(user_id, password)
340 @defer.inlineCallbacks
341 def _check_auth_dict(self, authdict, clientip):
342 """Attempt to validate the auth dict provided by a client
343
344 Args:
345 authdict (object): auth dict provided by the client
346 clientip (str): IP address of the client
347
348 Returns:
349 Deferred: result of the stage verification.
350
351 Raises:
352 StoreError if there was a problem accessing the database
353 SynapseError if there was a problem with the request
354 LoginError if there was an authentication problem.
355 """
356 login_type = authdict['type']
357 checker = self.checkers.get(login_type)
358 if checker is not None:
359 res = yield checker(authdict, clientip)
360 defer.returnValue(res)
361
362 # build a v1-login-style dict out of the authdict and fall back to the
363 # v1 code
364 user_id = authdict.get("user")
365
366 if user_id is None:
367 raise SynapseError(400, "", Codes.MISSING_PARAM)
368
369 (canonical_id, callback) = yield self.validate_login(user_id, authdict)
370 defer.returnValue(canonical_id)
272371
273372 @defer.inlineCallbacks
274373 def _check_recaptcha(self, authdict, clientip):
397496
398497 return self.sessions[session_id]
399498
400 def validate_password_login(self, user_id, password):
401 """
402 Authenticates the user with their username and password.
403
404 Used only by the v1 login API.
405
406 Args:
407 user_id (str): complete @user:id
408 password (str): Password
409 Returns:
410 defer.Deferred: (str) canonical user id
411 Raises:
412 StoreError if there was a problem accessing the database
413 LoginError if there was an authentication problem.
414 """
415 return self._check_password(user_id, password)
416
417 @defer.inlineCallbacks
418 def get_access_token_for_user_id(self, user_id, device_id=None,
419 initial_display_name=None):
499 @defer.inlineCallbacks
500 def get_access_token_for_user_id(self, user_id, device_id=None):
420501 """
421502 Creates a new access token for the user with the given user ID.
422503
430511 device_id (str|None): the device ID to associate with the tokens.
431512 None to leave the tokens unassociated with a device (deprecated:
432513 we should always have a device ID)
433 initial_display_name (str): display name to associate with the
434 device if it needs re-registering
435514 Returns:
436515 The access token for the user's session.
437516 Raises:
438517 StoreError if there was a problem storing the token.
439 LoginError if there was an authentication problem.
440518 """
441519 logger.info("Logging in user %s on device %s", user_id, device_id)
442520 access_token = yield self.issue_access_token(user_id, device_id)
446524 # really don't want is active access_tokens without a record of the
447525 # device, so we double-check it here.
448526 if device_id is not None:
449 yield self.device_handler.check_device_registered(
450 user_id, device_id, initial_display_name
451 )
527 try:
528 yield self.store.get_device(user_id, device_id)
529 except StoreError:
530 yield self.store.delete_access_token(access_token)
531 raise StoreError(400, "Login raced against device deletion")
452532
453533 defer.returnValue(access_token)
454534
500580 )
501581 defer.returnValue(result)
502582
503 @defer.inlineCallbacks
504 def _check_password(self, user_id, password):
505 """Authenticate a user against the LDAP and local databases.
506
507 user_id is checked case insensitively against the local database, but
508 will throw if there are multiple inexact matches.
509
510 Args:
511 user_id (str): complete @user:id
512 Returns:
513 (str) the canonical_user_id
583 def get_supported_login_types(self):
584 """Get the login types supported by the /login API
585
586 By default this is just 'm.login.password' (unless password_enabled is
587 False in the config file), but password auth providers can provide
588 other login types.
589
590 Returns:
591 Iterable[str]: login types
592 """
593 return self._supported_login_types
594
595 @defer.inlineCallbacks
596 def validate_login(self, username, login_submission):
597 """Authenticates the user for the /login API
598
599 Also used by the user-interactive auth flow to validate
600 m.login.password auth types.
601
602 Args:
603 username (str): username supplied by the user
604 login_submission (dict): the whole of the login submission
605 (including 'type' and other relevant fields)
606 Returns:
607 Deferred[str, func]: canonical user id, and optional callback
608 to be called once the access token and device id are issued
514609 Raises:
515 LoginError if login fails
516 """
610 StoreError if there was a problem accessing the database
611 SynapseError if there was a problem with the request
612 LoginError if there was an authentication problem.
613 """
614
615 if username.startswith('@'):
616 qualified_user_id = username
617 else:
618 qualified_user_id = UserID(
619 username, self.hs.hostname
620 ).to_string()
621
622 login_type = login_submission.get("type")
623 known_login_type = False
624
625 # special case to check for "password" for the check_password interface
626 # for the auth providers
627 password = login_submission.get("password")
628 if login_type == LoginType.PASSWORD:
629 if not self._password_enabled:
630 raise SynapseError(400, "Password login has been disabled.")
631 if not password:
632 raise SynapseError(400, "Missing parameter: password")
633
517634 for provider in self.password_providers:
518 is_valid = yield provider.check_password(user_id, password)
519 if is_valid:
520 defer.returnValue(user_id)
521
522 canonical_user_id = yield self._check_local_password(user_id, password)
523
524 if canonical_user_id:
525 defer.returnValue(canonical_user_id)
635 if (hasattr(provider, "check_password")
636 and login_type == LoginType.PASSWORD):
637 known_login_type = True
638 is_valid = yield provider.check_password(
639 qualified_user_id, password,
640 )
641 if is_valid:
642 defer.returnValue((qualified_user_id, None))
643
644 if (not hasattr(provider, "get_supported_login_types")
645 or not hasattr(provider, "check_auth")):
646 # this password provider doesn't understand custom login types
647 continue
648
649 supported_login_types = provider.get_supported_login_types()
650 if login_type not in supported_login_types:
651 # this password provider doesn't understand this login type
652 continue
653
654 known_login_type = True
655 login_fields = supported_login_types[login_type]
656
657 missing_fields = []
658 login_dict = {}
659 for f in login_fields:
660 if f not in login_submission:
661 missing_fields.append(f)
662 else:
663 login_dict[f] = login_submission[f]
664 if missing_fields:
665 raise SynapseError(
666 400, "Missing parameters for login type %s: %s" % (
667 login_type,
668 missing_fields,
669 ),
670 )
671
672 result = yield provider.check_auth(
673 username, login_type, login_dict,
674 )
675 if result:
676 if isinstance(result, str):
677 result = (result, None)
678 defer.returnValue(result)
679
680 if login_type == LoginType.PASSWORD:
681 known_login_type = True
682
683 canonical_user_id = yield self._check_local_password(
684 qualified_user_id, password,
685 )
686
687 if canonical_user_id:
688 defer.returnValue((canonical_user_id, None))
689
690 if not known_login_type:
691 raise SynapseError(400, "Unknown login type %s" % login_type)
526692
527693 # unknown username or invalid password. We raise a 403 here, but note
528694 # that if we're doing user-interactive login, it turns all LoginErrors
548714 if not lookupres:
549715 defer.returnValue(None)
550716 (user_id, password_hash) = lookupres
551 result = self.validate_hash(password, password_hash)
717 result = yield self.validate_hash(password, password_hash)
552718 if not result:
553719 logger.warn("Failed password login for user %s", user_id)
554720 defer.returnValue(None)
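validate_login above probes each configured password provider for check_password, get_supported_login_types and check_auth. A toy provider module satisfying that contract could look roughly like this; the login type, token value and returned domain are invented for illustration:

from twisted.internet import defer


class DemoAuthProvider(object):
    """Toy password provider accepting a single hard-coded token login."""

    def __init__(self, config, account_handler):
        self._account_handler = account_handler

    @staticmethod
    def parse_config(config):
        # pass-through; real providers validate their config here
        return config

    def get_supported_login_types(self):
        # maps each login type to the fields validate_login must collect
        # from the login submission before calling check_auth
        return {"com.example.demo_token": ("token",)}

    @defer.inlineCallbacks
    def check_auth(self, username, login_type, login_dict):
        # return the canonical user id on success; a falsy value lets
        # validate_login fall through to the next provider
        if login_dict.get("token") == "letmein":
            defer.returnValue("@%s:example.com" % (username,))
        defer.returnValue(None)

    def on_logged_out(self, user_id, device_id, access_token):
        # optional hook invoked by delete_access_token(s) later in this file
        return defer.succeed(None)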
572738 raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
573739
574740 @defer.inlineCallbacks
575 def set_password(self, user_id, newpassword, requester=None):
576 password_hash = self.hash(newpassword)
577
578 except_access_token_id = requester.access_token_id if requester else None
579
580 try:
581 yield self.store.user_set_password_hash(user_id, password_hash)
582 except StoreError as e:
583 if e.code == 404:
584 raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
585 raise e
586 yield self.store.user_delete_access_tokens(
587 user_id, except_access_token_id
741 def delete_access_token(self, access_token):
742 """Invalidate a single access token
743
744 Args:
745 access_token (str): access token to be deleted
746
747 Returns:
748 Deferred
749 """
750 user_info = yield self.auth.get_user_by_access_token(access_token)
751 yield self.store.delete_access_token(access_token)
752
753 # see if any of our auth providers want to know about this
754 for provider in self.password_providers:
755 if hasattr(provider, "on_logged_out"):
756 yield provider.on_logged_out(
757 user_id=str(user_info["user"]),
758 device_id=user_info["device_id"],
759 access_token=access_token,
760 )
761
762 # delete pushers associated with this access token
763 if user_info["token_id"] is not None:
764 yield self.hs.get_pusherpool().remove_pushers_by_access_token(
765 str(user_info["user"]), (user_info["token_id"], )
766 )
767
768 @defer.inlineCallbacks
769 def delete_access_tokens_for_user(self, user_id, except_token_id=None,
770 device_id=None):
771 """Invalidate access tokens belonging to a user
772
773 Args:
774 user_id (str): ID of user the tokens belong to
775 except_token_id (str|None): access_token ID which should *not* be
776 deleted
777 device_id (str|None): ID of device the tokens are associated with.
778 If None, tokens associated with any device (or no device) will
779 be deleted
780 Returns:
781 Deferred
782 """
783 tokens_and_devices = yield self.store.user_delete_access_tokens(
784 user_id, except_token_id=except_token_id, device_id=device_id,
588785 )
589 yield self.hs.get_pusherpool().remove_pushers_by_user(
590 user_id, except_access_token_id
786
787 # see if any of our auth providers want to know about this
788 for provider in self.password_providers:
789 if hasattr(provider, "on_logged_out"):
790 for token, token_id, device_id in tokens_and_devices:
791 yield provider.on_logged_out(
792 user_id=user_id,
793 device_id=device_id,
794 access_token=token,
795 )
796
797 # delete pushers associated with the access tokens
798 yield self.hs.get_pusherpool().remove_pushers_by_access_token(
799 user_id, (token_id for _, token_id, _ in tokens_and_devices),
591800 )
592801
593802 @defer.inlineCallbacks
633842 password (str): Password to hash.
634843
635844 Returns:
636 Hashed password (str).
637 """
638 return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
639 bcrypt.gensalt(self.bcrypt_rounds))
845 Deferred(str): Hashed password.
846 """
847 def _do_hash():
848 return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
849 bcrypt.gensalt(self.bcrypt_rounds))
850
851 return make_deferred_yieldable(threads.deferToThread(_do_hash))
640852
641853 def validate_hash(self, password, stored_hash):
642854 """Validates that self.hash(password) == stored_hash.
646858 stored_hash (str): Expected hash value.
647859
648860 Returns:
649 Whether self.hash(password) == stored_hash (bool).
650 """
861 Deferred(bool): Whether self.hash(password) == stored_hash.
862 """
863
864 def _do_validate_hash():
865 return bcrypt.checkpw(
866 password.encode('utf8') + self.hs.config.password_pepper,
867 stored_hash.encode('utf8')
868 )
869
651870 if stored_hash:
652 return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
653 stored_hash.encode('utf8')) == stored_hash
871 return make_deferred_yieldable(threads.deferToThread(_do_validate_hash))
654872 else:
655 return False
873 return defer.succeed(False)
656874
657875
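Note that hash() and validate_hash() now return Deferreds (the bcrypt work is pushed onto a threadpool so it no longer blocks the reactor), so callers have to yield them, as _check_local_password above was updated to do. A minimal caller sketch:

from twisted.internet import defer

@defer.inlineCallbacks
def password_matches(auth_handler, password, stored_hash):
    # auth_handler is an AuthHandler instance; both arguments are plain strings.
    ok = yield auth_handler.validate_hash(password, stored_hash)
    defer.returnValue(ok)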
658876 class MacaroonGeneartor(object):
695913 macaroon.add_first_party_caveat("gen = 1")
696914 macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
697915 return macaroon
698
699
700 class _AccountHandler(object):
701 """A proxy object that gets passed to password auth providers so they
702 can register new users etc if necessary.
703 """
704 def __init__(self, hs, check_user_exists):
705 self.hs = hs
706
707 self._check_user_exists = check_user_exists
708
709 def check_user_exists(self, user_id):
710 """Check if user exissts.
711
712 Returns:
713 Deferred(bool)
714 """
715 return self._check_user_exists(user_id)
716
717 def register(self, localpart):
718 """Registers a new user with given localpart
719
720 Returns:
721 Deferred: a 2-tuple of (user_id, access_token)
722 """
723 reg = self.hs.get_handlers().registration_handler
724 return reg.register(localpart=localpart)
0 # -*- coding: utf-8 -*-
1 # Copyright 2017 New Vector Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 from twisted.internet import defer
15
16 from ._base import BaseHandler
17
18 import logging
19
20 logger = logging.getLogger(__name__)
21
22
23 class DeactivateAccountHandler(BaseHandler):
24 """Handler which deals with deactivating user accounts."""
25 def __init__(self, hs):
26 super(DeactivateAccountHandler, self).__init__(hs)
27 self._auth_handler = hs.get_auth_handler()
28 self._device_handler = hs.get_device_handler()
29
30 @defer.inlineCallbacks
31 def deactivate_account(self, user_id):
32 """Deactivate a user's account
33
34 Args:
35 user_id (str): ID of user to be deactivated
36
37 Returns:
38 Deferred
39 """
40 # FIXME: Theoretically there is a race here wherein user resets
41 # password using threepid.
42
43 # first delete any devices belonging to the user, which will also
44 # delete corresponding access tokens.
45 yield self._device_handler.delete_all_devices_for_user(user_id)
46 # then delete any remaining access tokens which weren't associated with
47 # a device.
48 yield self._auth_handler.delete_access_tokens_for_user(user_id)
49
50 yield self.store.user_delete_threepids(user_id)
51 yield self.store.user_set_password_hash(user_id, None)
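A hedged usage sketch for the new handler; in practice it is reached through the HomeServer object rather than constructed directly, but construction is shown inline here to stay within what this diff defines:

from twisted.internet import defer

@defer.inlineCallbacks
def deactivate_example(hs):
    # hs is a HomeServer; the user id is a made-up example.
    handler = DeactivateAccountHandler(hs)
    yield handler.deactivate_account("@doomed_user:example.com")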
1313 # limitations under the License.
1414 from synapse.api import errors
1515 from synapse.api.constants import EventTypes
16 from synapse.api.errors import FederationDeniedError
1617 from synapse.util import stringutils
1718 from synapse.util.async import Linearizer
1819 from synapse.util.caches.expiringcache import ExpiringCache
3334
3435 self.hs = hs
3536 self.state = hs.get_state_handler()
37 self._auth_handler = hs.get_auth_handler()
3638 self.federation_sender = hs.get_federation_sender()
37 self.federation = hs.get_replication_layer()
3839
3940 self._edu_updater = DeviceListEduUpdater(hs, self)
4041
41 self.federation.register_edu_handler(
42 federation_registry = hs.get_federation_registry()
43
44 federation_registry.register_edu_handler(
4245 "m.device_list_update", self._edu_updater.incoming_device_list_update,
4346 )
44 self.federation.register_query_handler(
47 federation_registry.register_query_handler(
4548 "user_devices", self.on_federation_query_user_devices,
4649 )
4750
158161 else:
159162 raise
160163
161 yield self.store.user_delete_access_tokens(
164 yield self._auth_handler.delete_access_tokens_for_user(
162165 user_id, device_id=device_id,
163 delete_refresh_tokens=True,
164166 )
165167
166168 yield self.store.delete_e2e_keys_by_device(
170172 yield self.notify_device_update(user_id, [device_id])
171173
172174 @defer.inlineCallbacks
175 def delete_all_devices_for_user(self, user_id, except_device_id=None):
176 """Delete all of the user's devices
177
178 Args:
179 user_id (str):
180 except_device_id (str|None): optional device id which should not
181 be deleted
182
183 Returns:
184 defer.Deferred:
185 """
186 device_map = yield self.store.get_devices_by_user(user_id)
187 device_ids = device_map.keys()
188 if except_device_id is not None:
189 device_ids = [d for d in device_ids if d != except_device_id]
190 yield self.delete_devices(user_id, device_ids)
191
192 @defer.inlineCallbacks
173193 def delete_devices(self, user_id, device_ids):
174194 """ Delete several devices
175195
176196 Args:
177197 user_id (str):
178 device_ids (str): The list of device IDs to delete
198 device_ids (List[str]): The list of device IDs to delete
179199
180200 Returns:
181201 defer.Deferred:
193213 # Delete access tokens and e2e keys for each device. Not optimised as it is not
194214 # considered as part of a critical path.
195215 for device_id in device_ids:
196 yield self.store.user_delete_access_tokens(
216 yield self._auth_handler.delete_access_tokens_for_user(
197217 user_id, device_id=device_id,
198 delete_refresh_tokens=True,
199218 )
200219 yield self.store.delete_e2e_keys_by_device(
201220 user_id=user_id, device_id=device_id
411430
412431 def __init__(self, hs, device_handler):
413432 self.store = hs.get_datastore()
414 self.federation = hs.get_replication_layer()
433 self.federation = hs.get_federation_client()
415434 self.clock = hs.get_clock()
416435 self.device_handler = device_handler
417436
495514 # This makes it more likely that the device lists will
496515 # eventually become consistent.
497516 return
517 except FederationDeniedError as e:
518 logger.info(e)
519 return
498520 except Exception:
499521 # TODO: Remember that we are now out of sync and try again
500522 # later
1616
1717 from twisted.internet import defer
1818
19 from synapse.types import get_domain_from_id
19 from synapse.api.errors import SynapseError
20 from synapse.types import get_domain_from_id, UserID
2021 from synapse.util.stringutils import random_string
2122
2223
3233 """
3334 self.store = hs.get_datastore()
3435 self.notifier = hs.get_notifier()
35 self.is_mine_id = hs.is_mine_id
36 self.is_mine = hs.is_mine
3637 self.federation = hs.get_federation_sender()
3738
38 hs.get_replication_layer().register_edu_handler(
39 hs.get_federation_registry().register_edu_handler(
3940 "m.direct_to_device", self.on_direct_to_device_edu
4041 )
4142
5152 message_type = content["type"]
5253 message_id = content["message_id"]
5354 for user_id, by_device in content["messages"].items():
55 # we use UserID.from_string to catch invalid user ids
56 if not self.is_mine(UserID.from_string(user_id)):
57 logger.warning("Request for keys for non-local user %s",
58 user_id)
59 raise SynapseError(400, "Not a user here")
60
5461 messages_by_device = {
5562 device_id: {
5663 "content": message_content,
7683 local_messages = {}
7784 remote_messages = {}
7885 for user_id, by_device in messages.items():
79 if self.is_mine_id(user_id):
86 # we use UserID.from_string to catch invalid user ids
87 if self.is_mine(UserID.from_string(user_id)):
8088 messages_by_device = {
8189 device_id: {
8290 "content": message_content,
3333
3434 self.state = hs.get_state_handler()
3535 self.appservice_handler = hs.get_application_service_handler()
36
37 self.federation = hs.get_replication_layer()
38 self.federation.register_query_handler(
36 self.event_creation_handler = hs.get_event_creation_handler()
37
38 self.federation = hs.get_federation_client()
39 hs.get_federation_registry().register_query_handler(
3940 "directory", self.on_directory_query
4041 )
4142
248249 def send_room_alias_update_event(self, requester, user_id, room_id):
249250 aliases = yield self.store.get_aliases_for_room(room_id)
250251
251 msg_handler = self.hs.get_handlers().message_handler
252 yield msg_handler.create_and_send_nonmember_event(
252 yield self.event_creation_handler.create_and_send_nonmember_event(
253253 requester,
254254 {
255255 "type": EventTypes.Aliases,
271271 if not alias_event or alias_event.content.get("alias", "") != alias_str:
272272 return
273273
274 msg_handler = self.hs.get_handlers().message_handler
275 yield msg_handler.create_and_send_nonmember_event(
274 yield self.event_creation_handler.create_and_send_nonmember_event(
276275 requester,
277276 {
278277 "type": EventTypes.CanonicalAlias,
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 import ujson as json
15 import simplejson as json
1616 import logging
1717
1818 from canonicaljson import encode_canonical_json
1919 from twisted.internet import defer
2020
21 from synapse.api.errors import SynapseError, CodeMessageException
22 from synapse.types import get_domain_from_id
21 from synapse.api.errors import (
22 SynapseError, CodeMessageException, FederationDeniedError,
23 )
24 from synapse.types import get_domain_from_id, UserID
2325 from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
2426 from synapse.util.retryutils import NotRetryingDestination
2527
2931 class E2eKeysHandler(object):
3032 def __init__(self, hs):
3133 self.store = hs.get_datastore()
32 self.federation = hs.get_replication_layer()
34 self.federation = hs.get_federation_client()
3335 self.device_handler = hs.get_device_handler()
34 self.is_mine_id = hs.is_mine_id
36 self.is_mine = hs.is_mine
3537 self.clock = hs.get_clock()
3638
3739 # doesn't really work as part of the generic query API, because the
3840 # query request requires an object POST, but we abuse the
3941 # "query handler" interface.
40 self.federation.register_query_handler(
42 hs.get_federation_registry().register_query_handler(
4143 "client_keys", self.on_federation_query_client_keys
4244 )
4345
6971 remote_queries = {}
7072
7173 for user_id, device_ids in device_keys_query.items():
72 if self.is_mine_id(user_id):
74 # we use UserID.from_string to catch invalid user ids
75 if self.is_mine(UserID.from_string(user_id)):
7376 local_query[user_id] = device_ids
7477 else:
7578 remote_queries[user_id] = device_ids
138141 failures[destination] = {
139142 "status": 503, "message": "Not ready for retry",
140143 }
144 except FederationDeniedError as e:
145 failures[destination] = {
146 "status": 403, "message": "Federation Denied",
147 }
141148 except Exception as e:
142149 # include ConnectionRefused and other errors
143150 failures[destination] = {
169176
170177 result_dict = {}
171178 for user_id, device_ids in query.items():
172 if not self.is_mine_id(user_id):
179 # we use UserID.from_string to catch invalid user ids
180 if not self.is_mine(UserID.from_string(user_id)):
173181 logger.warning("Request for keys for non-local user %s",
174182 user_id)
175183 raise SynapseError(400, "Not a user here")
212220 remote_queries = {}
213221
214222 for user_id, device_keys in query.get("one_time_keys", {}).items():
215 if self.is_mine_id(user_id):
223 # we use UserID.from_string to catch invalid user ids
224 if self.is_mine(UserID.from_string(user_id)):
216225 for device_id, algorithm in device_keys.items():
217226 local_query.append((user_id, device_id, algorithm))
218227 else:
00 # -*- coding: utf-8 -*-
11 # Copyright 2014-2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
2122
2223 from synapse.api.errors import (
2324 AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
25 FederationDeniedError,
2426 )
2527 from synapse.api.constants import EventTypes, Membership, RejectedReason
2628 from synapse.events.validator import EventValidator
6567 self.hs = hs
6668
6769 self.store = hs.get_datastore()
68 self.replication_layer = hs.get_replication_layer()
70 self.replication_layer = hs.get_federation_client()
6971 self.state_handler = hs.get_state_handler()
7072 self.server_name = hs.hostname
7173 self.keyring = hs.get_keyring()
7375 self.is_mine_id = hs.is_mine_id
7476 self.pusher_pool = hs.get_pusherpool()
7577 self.spam_checker = hs.get_spam_checker()
76
77 self.replication_layer.set_handler(self)
78 self.event_creation_handler = hs.get_event_creation_handler()
7879
7980 # When joining a room we need to queue any events for that room up
8081 self.room_queues = {}
226227 state, auth_chain = yield self.replication_layer.get_state_for_room(
227228 origin, pdu.room_id, pdu.event_id,
228229 )
229 except:
230 except Exception:
230231 logger.exception("Failed to get state for event: %s", pdu.event_id)
231232
232233 yield self._process_received_pdu(
460461 def check_match(id):
461462 try:
462463 return server_name == get_domain_from_id(id)
463 except:
464 except Exception:
464465 return False
465466
466467 # Parses mapping `event_id -> (type, state_key) -> state event_id`
498499 continue
499500 try:
500501 domain = get_domain_from_id(ev.state_key)
501 except:
502 except Exception:
502503 continue
503504
504505 if domain != server_name:
737738 joined_domains[dom] = min(d, old_d)
738739 else:
739740 joined_domains[dom] = d
740 except:
741 except Exception:
741742 pass
742743
743744 return sorted(joined_domains.items(), key=lambda d: d[1])
781782 except NotRetryingDestination as e:
782783 logger.info(e.message)
783784 continue
785 except FederationDeniedError as e:
786 logger.info(e)
787 continue
784788 except Exception as e:
785789 logger.exception(
786790 "Failed to backfill from %s because %s",
803807 event_ids = list(extremities.keys())
804808
805809 logger.debug("calling resolve_state_groups in _maybe_backfill")
810 resolve = logcontext.preserve_fn(
811 self.state_handler.resolve_state_groups_for_events
812 )
806813 states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
807 [
808 logcontext.preserve_fn(self.state_handler.resolve_state_groups)(
809 room_id, [e]
810 )
811 for e in event_ids
812 ], consumeErrors=True,
814 [resolve(room_id, [e]) for e in event_ids],
815 consumeErrors=True,
813816 ))
814817 states = dict(zip(event_ids, [s.state for s in states]))
815818
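The backfill change above resolves state groups for each extremity in parallel and then zips the results back up by event id. A minimal Twisted sketch of that pattern follows; resolve is a stand-in for resolve_state_groups_for_events.

from twisted.internet import defer


@defer.inlineCallbacks
def resolve_all(resolve, room_id, event_ids):
    # One deferred per event; consumeErrors=True stops a single failure from
    # being reported as an unhandled error in the others.
    results = yield defer.gatherResults(
        [resolve(room_id, [e]) for e in event_ids],
        consumeErrors=True,
    )
    defer.returnValue(dict(zip(event_ids, results)))


collected = []
d = resolve_all(lambda room, evs: defer.succeed("state-for-" + evs[0]),
                "!room:example.com", ["$a:hs", "$b:hs"])
d.addCallback(collected.append)  # fires synchronously here: the inputs are already resolved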
939942 room_creator_user_id="",
940943 is_public=False
941944 )
942 except:
945 except Exception:
943946 # FIXME
944947 pass
945948
10031006 })
10041007
10051008 try:
1006 message_handler = self.hs.get_handlers().message_handler
1007 event, context = yield message_handler._create_new_client_event(
1009 event, context = yield self.event_creation_handler.create_new_client_event(
10081010 builder=builder,
10091011 )
10101012 except AuthError as e:
12441246 "state_key": user_id,
12451247 })
12461248
1247 message_handler = self.hs.get_handlers().message_handler
1248 event, context = yield message_handler._create_new_client_event(
1249 event, context = yield self.event_creation_handler.create_new_client_event(
12491250 builder=builder,
12501251 )
12511252
14431444 auth_events=auth_events,
14441445 )
14451446
1446 if not event.internal_metadata.is_outlier() and not backfilled:
1447 yield self.action_generator.handle_push_actions_for_event(
1448 event, context
1449 )
1450
1451 event_stream_id, max_stream_id = yield self.store.persist_event(
1452 event,
1453 context=context,
1454 backfilled=backfilled,
1455 )
1447 try:
1448 if not event.internal_metadata.is_outlier() and not backfilled:
1449 yield self.action_generator.handle_push_actions_for_event(
1450 event, context
1451 )
1452
1453 event_stream_id, max_stream_id = yield self.store.persist_event(
1454 event,
1455 context=context,
1456 backfilled=backfilled,
1457 )
1458 except: # noqa: E722, as we reraise the exception this is fine.
1459 # Ensure that we actually remove the entries in the push actions
1460 # staging area
1461 logcontext.preserve_fn(
1462 self.store.remove_push_actions_from_staging
1463 )(event.event_id)
1464 raise
14561465
14571466 if not backfilled:
14581467 # this intentionally does not yield: we don't care about the result
17051714 @defer.inlineCallbacks
17061715 @log_function
17071716 def do_auth(self, origin, event, context, auth_events):
1717 """
1718
1719 Args:
1720 origin (str):
1721 event (synapse.events.FrozenEvent):
1722 context (synapse.events.snapshot.EventContext):
1723 auth_events (dict[(str, str)->str]):
1724
1725 Returns:
1726 defer.Deferred[None]
1727 """
17081728 # Check if we have all the auth events.
17091729 current_state = set(e.event_id for e in auth_events.values())
17101730 event_auth_events = set(e_id for e_id, _ in event.auth_events)
17741794 [e_id for e_id, _ in event.auth_events]
17751795 )
17761796 seen_events = set(have_events.keys())
1777 except:
1797 except Exception:
17781798 # FIXME:
17791799 logger.exception("Failed to get auth chain")
17801800
18161836 current_state = set(e.event_id for e in auth_events.values())
18171837 different_auth = event_auth_events - current_state
18181838
1819 context.current_state_ids = dict(context.current_state_ids)
1820 context.current_state_ids.update({
1821 k: a.event_id for k, a in auth_events.items()
1822 if k != event_key
1823 })
1824 context.prev_state_ids = dict(context.prev_state_ids)
1825 context.prev_state_ids.update({
1826 k: a.event_id for k, a in auth_events.items()
1827 })
1828 context.state_group = self.store.get_next_state_group()
1839 yield self._update_context_for_auth_events(
1840 event, context, auth_events, event_key,
1841 )
18291842
18301843 if different_auth and not event.internal_metadata.is_outlier():
18311844 logger.info("Different auth after resolution: %s", different_auth)
18981911 except AuthError:
18991912 pass
19001913
1901 except:
1914 except Exception:
19021915 # FIXME:
19031916 logger.exception("Failed to query auth chain")
19041917
19051918 # 4. Look at rejects and their proofs.
19061919 # TODO.
19071920
1908 context.current_state_ids = dict(context.current_state_ids)
1909 context.current_state_ids.update({
1910 k: a.event_id for k, a in auth_events.items()
1911 if k != event_key
1912 })
1913 context.prev_state_ids = dict(context.prev_state_ids)
1914 context.prev_state_ids.update({
1915 k: a.event_id for k, a in auth_events.items()
1916 })
1917 context.state_group = self.store.get_next_state_group()
1921 yield self._update_context_for_auth_events(
1922 event, context, auth_events, event_key,
1923 )
19181924
19191925 try:
19201926 self.auth.check(event, auth_events=auth_events)
19231929 raise e
19241930
19251931 @defer.inlineCallbacks
1932 def _update_context_for_auth_events(self, event, context, auth_events,
1933 event_key):
1934 """Update the state_ids in an event context after auth event resolution,
1935 storing the changes as a new state group.
1936
1937 Args:
1938 event (Event): The event we're handling the context for
1939
1940 context (synapse.events.snapshot.EventContext): event context
1941 to be updated
1942
1943 auth_events (dict[(str, str)->str]): Events to update in the event
1944 context.
1945
1946 event_key ((str, str)): (type, state_key) for the current event.
1947 this will not be included in the current_state in the context.
1948 """
1949 state_updates = {
1950 k: a.event_id for k, a in auth_events.iteritems()
1951 if k != event_key
1952 }
1953 context.current_state_ids = dict(context.current_state_ids)
1954 context.current_state_ids.update(state_updates)
1955 if context.delta_ids is not None:
1956 context.delta_ids = dict(context.delta_ids)
1957 context.delta_ids.update(state_updates)
1958 context.prev_state_ids = dict(context.prev_state_ids)
1959 context.prev_state_ids.update({
1960 k: a.event_id for k, a in auth_events.iteritems()
1961 })
1962 context.state_group = yield self.store.store_state_group(
1963 event.event_id,
1964 event.room_id,
1965 prev_group=context.prev_group,
1966 delta_ids=context.delta_ids,
1967 current_state_ids=context.current_state_ids,
1968 )
1969
1970 @defer.inlineCallbacks
19261971 def construct_auth_difference(self, local_auth, remote_auth):
19271972 """ Given a local and remote auth chain, find the differences. This
19281973 assumes that we have already processed all events in remote_auth
19652010 def get_next(it, opt=None):
19662011 try:
19672012 return it.next()
1968 except:
2013 except Exception:
19692014 return opt
19702015
19712016 current_local = get_next(local_iter)
20902135 if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
20912136 builder = self.event_builder_factory.new(event_dict)
20922137 EventValidator().validate_new(builder)
2093 message_handler = self.hs.get_handlers().message_handler
2094 event, context = yield message_handler._create_new_client_event(
2138 event, context = yield self.event_creation_handler.create_new_client_event(
20952139 builder=builder
20962140 )
20972141
21062150 raise e
21072151
21082152 yield self._check_signature(event, context)
2109 member_handler = self.hs.get_handlers().room_member_handler
2153 member_handler = self.hs.get_room_member_handler()
21102154 yield member_handler.send_membership_event(None, event, context)
21112155 else:
21122156 destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
21292173 """
21302174 builder = self.event_builder_factory.new(event_dict)
21312175
2132 message_handler = self.hs.get_handlers().message_handler
2133 event, context = yield message_handler._create_new_client_event(
2176 event, context = yield self.event_creation_handler.create_new_client_event(
21342177 builder=builder,
21352178 )
21362179
21512194 # TODO: Make sure the signatures actually are correct.
21522195 event.signatures.update(returned_invite.signatures)
21532196
2154 member_handler = self.hs.get_handlers().room_member_handler
2197 member_handler = self.hs.get_room_member_handler()
21552198 yield member_handler.send_membership_event(None, event, context)
21562199
21572200 @defer.inlineCallbacks
21802223
21812224 builder = self.event_builder_factory.new(event_dict)
21822225 EventValidator().validate_new(builder)
2183 message_handler = self.hs.get_handlers().message_handler
2184 event, context = yield message_handler._create_new_client_event(builder=builder)
2226 event, context = yield self.event_creation_handler.create_new_client_event(
2227 builder=builder,
2228 )
21852229 defer.returnValue((event, context))
21862230
21872231 @defer.inlineCallbacks
7070 get_invited_users_in_group = _create_rerouter("get_invited_users_in_group")
7171
7272 add_room_to_group = _create_rerouter("add_room_to_group")
73 update_room_in_group = _create_rerouter("update_room_in_group")
7374 remove_room_from_group = _create_rerouter("remove_room_from_group")
7475
7576 update_group_summary_room = _create_rerouter("update_group_summary_room")
373374 def get_publicised_groups_for_user(self, user_id):
374375 if self.hs.is_mine_id(user_id):
375376 result = yield self.store.get_publicised_groups_for_user(user_id)
377
378 # Check AS associated groups for this user - this depends on the
379 # RegExps in the AS registration file (under `users`)
380 for app_service in self.store.get_app_services():
381 result.extend(app_service.get_groups_for_user(user_id))
382
376383 defer.returnValue({"groups": result})
377384 else:
378 result = yield self.transport_client.get_publicised_groups_for_user(
379 get_domain_from_id(user_id), user_id
380 )
385 bulk_result = yield self.transport_client.bulk_get_publicised_groups(
386 get_domain_from_id(user_id), [user_id],
387 )
388 result = bulk_result.get("users", {}).get(user_id)
381389 # TODO: Verify attestations
382 defer.returnValue(result)
390 defer.returnValue({"groups": result})
383391
384392 @defer.inlineCallbacks
385393 def bulk_get_publicised_groups(self, user_ids, proxy=True):
413421 uid
414422 )
415423
424 # Check AS associated groups for this user - this depends on the
425 # RegExps in the AS registration file (under `users`)
426 for app_service in self.store.get_app_services():
427 results[uid].extend(app_service.get_groups_for_user(uid))
428
416429 defer.returnValue({"users": results})
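The added lines above also surface groups advertised by application services. A hedged illustration of how a users namespace with a regex might map a matching user id to a group; the registration schema shown here is an assumption, not the documented format.

import re


def groups_for_user_sketch(user_namespaces, user_id):
    # Return the group_id of every namespace whose regex matches this user.
    return [
        ns["group_id"]
        for ns in user_namespaces
        if "group_id" in ns and re.match(ns["regex"], user_id)
    ]


namespaces = [{"regex": r"@irc_.*:example\.com", "group_id": "+irc:example.com"}]
print(groups_for_user_sketch(namespaces, "@irc_alice:example.com"))  # ['+irc:example.com']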
2626 from synapse.util import unwrapFirstError
2727 from synapse.util.async import concurrently_execute
2828 from synapse.util.caches.snapshot_cache import SnapshotCache
29 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
29 from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
3030 from synapse.visibility import filter_events_for_client
3131
3232 from ._base import BaseHandler
162162 lambda states: states[event.event_id]
163163 )
164164
165 (messages, token), current_state = yield preserve_context_over_deferred(
165 (messages, token), current_state = yield make_deferred_yieldable(
166166 defer.gatherResults(
167167 [
168168 preserve_fn(self.store.get_recent_events_for_room)(
213213 })
214214
215215 d["account_data"] = account_data_events
216 except:
216 except Exception:
217217 logger.exception("Failed to get snapshot")
218218
219219 yield concurrently_execute(handle_room, room_list, 10)
00 # -*- coding: utf-8 -*-
11 # Copyright 2014 - 2016 OpenMarket Ltd
2 # Copyright 2017 New Vector Ltd
2 # Copyright 2017 - 2018 New Vector Ltd
33 #
44 # Licensed under the Apache License, Version 2.0 (the "License");
55 # you may not use this file except in compliance with the License.
1212 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
15 from twisted.internet import defer
15 from twisted.internet import defer, reactor
16 from twisted.python.failure import Failure
1617
1718 from synapse.api.constants import EventTypes, Membership
1819 from synapse.api.errors import AuthError, Codes, SynapseError
2324 UserID, RoomAlias, RoomStreamToken,
2425 )
2526 from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
26 from synapse.util.logcontext import preserve_fn
27 from synapse.util.logcontext import preserve_fn, run_in_background
2728 from synapse.util.metrics import measure_func
2829 from synapse.util.frozenutils import unfreeze
30 from synapse.util.stringutils import random_string
2931 from synapse.visibility import filter_events_for_client
32 from synapse.replication.http.send_event import send_event_to_master
3033
3134 from ._base import BaseHandler
3235
3437
3538 import logging
3639 import random
37 import ujson
40 import simplejson
3841
3942 logger = logging.getLogger(__name__)
43
44
45 class PurgeStatus(object):
46 """Object tracking the status of a purge request
47
48 This class contains information on the progress of a purge request, for
49 return by get_purge_status.
50
51 Attributes:
52 status (int): Tracks whether this request has completed. One of
53 STATUS_{ACTIVE,COMPLETE,FAILED}
54 """
55
56 STATUS_ACTIVE = 0
57 STATUS_COMPLETE = 1
58 STATUS_FAILED = 2
59
60 STATUS_TEXT = {
61 STATUS_ACTIVE: "active",
62 STATUS_COMPLETE: "complete",
63 STATUS_FAILED: "failed",
64 }
65
66 def __init__(self):
67 self.status = PurgeStatus.STATUS_ACTIVE
68
69 def asdict(self):
70 return {
71 "status": PurgeStatus.STATUS_TEXT[self.status]
72 }
4073
4174
4275 class MessageHandler(BaseHandler):
4679 self.hs = hs
4780 self.state = hs.get_state_handler()
4881 self.clock = hs.get_clock()
49 self.validator = EventValidator()
50 self.profile_handler = hs.get_profile_handler()
5182
5283 self.pagination_lock = ReadWriteLock()
53
54 self.pusher_pool = hs.get_pusherpool()
55
56 # We arbitrarily limit concurrent event creation for a room to 5.
57 # This is to stop us from diverging history *too* much.
58 self.limiter = Limiter(max_count=5)
59
60 self.action_generator = hs.get_action_generator()
61
62 self.spam_checker = hs.get_spam_checker()
63
64 @defer.inlineCallbacks
65 def purge_history(self, room_id, event_id):
66 event = yield self.store.get_event(event_id)
67
68 if event.room_id != room_id:
69 raise SynapseError(400, "Event is for wrong room.")
70
71 depth = event.depth
72
73 with (yield self.pagination_lock.write(room_id)):
74 yield self.store.delete_old_state(room_id, depth)
84 self._purges_in_progress_by_room = set()
85 # map from purge id to PurgeStatus
86 self._purges_by_id = {}
87
88 def start_purge_history(self, room_id, topological_ordering,
89 delete_local_events=False):
90 """Start off a history purge on a room.
91
92 Args:
93 room_id (str): The room to purge from
94
95 topological_ordering (int): minimum topo ordering to preserve
96 delete_local_events (bool): True to delete local events as well as
97 remote ones
98
99 Returns:
100 str: unique ID for this purge transaction.
101 """
102 if room_id in self._purges_in_progress_by_room:
103 raise SynapseError(
104 400,
105 "History purge already in progress for %s" % (room_id, ),
106 )
107
108 purge_id = random_string(16)
109
110 # we log the purge_id here so that it can be tied back to the
111 # request id in the log lines.
112 logger.info("[purge] starting purge_id %s", purge_id)
113
114 self._purges_by_id[purge_id] = PurgeStatus()
115 run_in_background(
116 self._purge_history,
117 purge_id, room_id, topological_ordering, delete_local_events,
118 )
119 return purge_id
120
121 @defer.inlineCallbacks
122 def _purge_history(self, purge_id, room_id, topological_ordering,
123 delete_local_events):
124 """Carry out a history purge on a room.
125
126 Args:
127 purge_id (str): The id for this purge
128 room_id (str): The room to purge from
129 topological_ordering (int): minimum topo ordering to preserve
130 delete_local_events (bool): True to delete local events as well as
131 remote ones
132
133 Returns:
134 Deferred
135 """
136 self._purges_in_progress_by_room.add(room_id)
137 try:
138 with (yield self.pagination_lock.write(room_id)):
139 yield self.store.purge_history(
140 room_id, topological_ordering, delete_local_events,
141 )
142 logger.info("[purge] complete")
143 self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE
144 except Exception:
145 logger.error("[purge] failed: %s", Failure().getTraceback().rstrip())
146 self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
147 finally:
148 self._purges_in_progress_by_room.discard(room_id)
149
150 # remove the purge from the list 24 hours after it completes
151 def clear_purge():
152 del self._purges_by_id[purge_id]
153 reactor.callLater(24 * 3600, clear_purge)
154
155 def get_purge_status(self, purge_id):
156 """Get the current status of an active purge
157
158 Args:
159 purge_id (str): purge_id returned by start_purge_history
160
161 Returns:
162 PurgeStatus|None
163 """
164 return self._purges_by_id.get(purge_id)
75165
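PurgeStatus, start_purge_history and get_purge_status above form a start-then-poll API: callers receive an opaque purge_id immediately and ask for the status later. A cut-down, non-Twisted sketch of the same pattern (everything suffixed with Sketch is invented):

import random
import string


class PurgeStatusSketch(object):
    def __init__(self):
        self.status = "active"

    def asdict(self):
        return {"status": self.status}


class PurgeTrackerSketch(object):
    def __init__(self):
        self._purges_by_id = {}

    def start_purge_history(self):
        purge_id = "".join(random.choice(string.ascii_lowercase) for _ in range(16))
        self._purges_by_id[purge_id] = PurgeStatusSketch()
        # ... the real handler kicks the purge off in the background here ...
        return purge_id

    def get_purge_status(self, purge_id):
        return self._purges_by_id.get(purge_id)


tracker = PurgeTrackerSketch()
purge_id = tracker.start_purge_history()
print(tracker.get_purge_status(purge_id).asdict())  # {'status': 'active'}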
76166 @defer.inlineCallbacks
77167 def get_messages(self, requester, room_id=None, pagin_config=None,
180270 }
181271
182272 defer.returnValue(chunk)
183
184 @defer.inlineCallbacks
185 def create_event(self, requester, event_dict, token_id=None, txn_id=None,
186 prev_event_ids=None):
187 """
188 Given a dict from a client, create a new event.
189
190 Creates an FrozenEvent object, filling out auth_events, prev_events,
191 etc.
192
193 Adds display names to Join membership events.
194
195 Args:
196 requester
197 event_dict (dict): An entire event
198 token_id (str)
199 txn_id (str)
200 prev_event_ids (list): The prev event ids to use when creating the event
201
202 Returns:
203 Tuple of created event (FrozenEvent), Context
204 """
205 builder = self.event_builder_factory.new(event_dict)
206
207 with (yield self.limiter.queue(builder.room_id)):
208 self.validator.validate_new(builder)
209
210 if builder.type == EventTypes.Member:
211 membership = builder.content.get("membership", None)
212 target = UserID.from_string(builder.state_key)
213
214 if membership in {Membership.JOIN, Membership.INVITE}:
215 # If event doesn't include a display name, add one.
216 profile = self.profile_handler
217 content = builder.content
218
219 try:
220 if "displayname" not in content:
221 content["displayname"] = yield profile.get_displayname(target)
222 if "avatar_url" not in content:
223 content["avatar_url"] = yield profile.get_avatar_url(target)
224 except Exception as e:
225 logger.info(
226 "Failed to get profile information for %r: %s",
227 target, e
228 )
229
230 if token_id is not None:
231 builder.internal_metadata.token_id = token_id
232
233 if txn_id is not None:
234 builder.internal_metadata.txn_id = txn_id
235
236 event, context = yield self._create_new_client_event(
237 builder=builder,
238 requester=requester,
239 prev_event_ids=prev_event_ids,
240 )
241
242 defer.returnValue((event, context))
243
244 @defer.inlineCallbacks
245 def send_nonmember_event(self, requester, event, context, ratelimit=True):
246 """
247 Persists and notifies local clients and federation of an event.
248
249 Args:
250 event (FrozenEvent) the event to send.
251 context (Context) the context of the event.
252 ratelimit (bool): Whether to rate limit this send.
253 is_guest (bool): Whether the sender is a guest.
254 """
255 if event.type == EventTypes.Member:
256 raise SynapseError(
257 500,
258 "Tried to send member event through non-member codepath"
259 )
260
261 # We check here if we are currently being rate limited, so that we
262 # don't do unnecessary work. We check again just before we actually
263 # send the event.
264 yield self.ratelimit(requester, update=False)
265
266 user = UserID.from_string(event.sender)
267
268 assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
269
270 if event.is_state():
271 prev_state = yield self.deduplicate_state_event(event, context)
272 if prev_state is not None:
273 defer.returnValue(prev_state)
274
275 yield self.handle_new_client_event(
276 requester=requester,
277 event=event,
278 context=context,
279 ratelimit=ratelimit,
280 )
281
282 if event.type == EventTypes.Message:
283 presence = self.hs.get_presence_handler()
284 # We don't want to block sending messages on any presence code. This
285 # matters as sometimes presence code can take a while.
286 preserve_fn(presence.bump_presence_active_time)(user)
287
288 @defer.inlineCallbacks
289 def deduplicate_state_event(self, event, context):
290 """
291 Checks whether event is in the latest resolved state in context.
292
293 If so, returns the version of the event in context.
294 Otherwise, returns None.
295 """
296 prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
297 prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
298 if not prev_event:
299 return
300
301 if prev_event and event.user_id == prev_event.user_id:
302 prev_content = encode_canonical_json(prev_event.content)
303 next_content = encode_canonical_json(event.content)
304 if prev_content == next_content:
305 defer.returnValue(prev_event)
306 return
307
308 @defer.inlineCallbacks
309 def create_and_send_nonmember_event(
310 self,
311 requester,
312 event_dict,
313 ratelimit=True,
314 txn_id=None
315 ):
316 """
317 Creates an event, then sends it.
318
319 See self.create_event and self.send_nonmember_event.
320 """
321 event, context = yield self.create_event(
322 requester,
323 event_dict,
324 token_id=requester.access_token_id,
325 txn_id=txn_id
326 )
327
328 spam_error = self.spam_checker.check_event_for_spam(event)
329 if spam_error:
330 if not isinstance(spam_error, basestring):
331 spam_error = "Spam is not permitted here"
332 raise SynapseError(
333 403, spam_error, Codes.FORBIDDEN
334 )
335
336 yield self.send_nonmember_event(
337 requester,
338 event,
339 context,
340 ratelimit=ratelimit,
341 )
342 defer.returnValue(event)
343273
344274 @defer.inlineCallbacks
345275 def get_room_data(self, user_id=None, room_id=None,
469399 for user_id, profile in users_with_profile.iteritems()
470400 })
471401
472 @measure_func("_create_new_client_event")
473 @defer.inlineCallbacks
474 def _create_new_client_event(self, builder, requester=None, prev_event_ids=None):
402
403 class EventCreationHandler(object):
404 def __init__(self, hs):
405 self.hs = hs
406 self.auth = hs.get_auth()
407 self.store = hs.get_datastore()
408 self.state = hs.get_state_handler()
409 self.clock = hs.get_clock()
410 self.validator = EventValidator()
411 self.profile_handler = hs.get_profile_handler()
412 self.event_builder_factory = hs.get_event_builder_factory()
413 self.server_name = hs.hostname
414 self.ratelimiter = hs.get_ratelimiter()
415 self.notifier = hs.get_notifier()
416 self.config = hs.config
417
418 self.http_client = hs.get_simple_http_client()
419
420 # This is only used to get at ratelimit function, and maybe_kick_guest_users
421 self.base_handler = BaseHandler(hs)
422
423 self.pusher_pool = hs.get_pusherpool()
424
425 # We arbitrarily limit concurrent event creation for a room to 5.
426 # This is to stop us from diverging history *too* much.
427 self.limiter = Limiter(max_count=5)
428
429 self.action_generator = hs.get_action_generator()
430
431 self.spam_checker = hs.get_spam_checker()
432
433 @defer.inlineCallbacks
434 def create_event(self, requester, event_dict, token_id=None, txn_id=None,
435 prev_event_ids=None):
436 """
437 Given a dict from a client, create a new event.
438
439 Creates a FrozenEvent object, filling out auth_events, prev_events,
440 etc.
441
442 Adds display names to Join membership events.
443
444 Args:
445 requester
446 event_dict (dict): An entire event
447 token_id (str)
448 txn_id (str)
449 prev_event_ids (list): The prev event ids to use when creating the event
450
451 Returns:
452 Tuple of created event (FrozenEvent), Context
453 """
454 builder = self.event_builder_factory.new(event_dict)
455
456 with (yield self.limiter.queue(builder.room_id)):
457 self.validator.validate_new(builder)
458
459 if builder.type == EventTypes.Member:
460 membership = builder.content.get("membership", None)
461 target = UserID.from_string(builder.state_key)
462
463 if membership in {Membership.JOIN, Membership.INVITE}:
464 # If event doesn't include a display name, add one.
465 profile = self.profile_handler
466 content = builder.content
467
468 try:
469 if "displayname" not in content:
470 content["displayname"] = yield profile.get_displayname(target)
471 if "avatar_url" not in content:
472 content["avatar_url"] = yield profile.get_avatar_url(target)
473 except Exception as e:
474 logger.info(
475 "Failed to get profile information for %r: %s",
476 target, e
477 )
478
479 if token_id is not None:
480 builder.internal_metadata.token_id = token_id
481
482 if txn_id is not None:
483 builder.internal_metadata.txn_id = txn_id
484
485 event, context = yield self.create_new_client_event(
486 builder=builder,
487 requester=requester,
488 prev_event_ids=prev_event_ids,
489 )
490
491 defer.returnValue((event, context))
492
493 @defer.inlineCallbacks
494 def send_nonmember_event(self, requester, event, context, ratelimit=True):
495 """
496 Persists and notifies local clients and federation of an event.
497
498 Args:
499 event (FrozenEvent) the event to send.
500 context (Context) the context of the event.
501 ratelimit (bool): Whether to rate limit this send.
502 is_guest (bool): Whether the sender is a guest.
503 """
504 if event.type == EventTypes.Member:
505 raise SynapseError(
506 500,
507 "Tried to send member event through non-member codepath"
508 )
509
510 user = UserID.from_string(event.sender)
511
512 assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
513
514 if event.is_state():
515 prev_state = yield self.deduplicate_state_event(event, context)
516 if prev_state is not None:
517 defer.returnValue(prev_state)
518
519 yield self.handle_new_client_event(
520 requester=requester,
521 event=event,
522 context=context,
523 ratelimit=ratelimit,
524 )
525
526 @defer.inlineCallbacks
527 def deduplicate_state_event(self, event, context):
528 """
529 Checks whether event is in the latest resolved state in context.
530
531 If so, returns the version of the event in context.
532 Otherwise, returns None.
533 """
534 prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
535 prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
536 if not prev_event:
537 return
538
539 if prev_event and event.user_id == prev_event.user_id:
540 prev_content = encode_canonical_json(prev_event.content)
541 next_content = encode_canonical_json(event.content)
542 if prev_content == next_content:
543 defer.returnValue(prev_event)
544 return
545
546 @defer.inlineCallbacks
547 def create_and_send_nonmember_event(
548 self,
549 requester,
550 event_dict,
551 ratelimit=True,
552 txn_id=None
553 ):
554 """
555 Creates an event, then sends it.
556
557 See self.create_event and self.send_nonmember_event.
558 """
559 event, context = yield self.create_event(
560 requester,
561 event_dict,
562 token_id=requester.access_token_id,
563 txn_id=txn_id
564 )
565
566 spam_error = self.spam_checker.check_event_for_spam(event)
567 if spam_error:
568 if not isinstance(spam_error, basestring):
569 spam_error = "Spam is not permitted here"
570 raise SynapseError(
571 403, spam_error, Codes.FORBIDDEN
572 )
573
574 yield self.send_nonmember_event(
575 requester,
576 event,
577 context,
578 ratelimit=ratelimit,
579 )
580 defer.returnValue(event)
581
582 @measure_func("create_new_client_event")
583 @defer.inlineCallbacks
584 def create_new_client_event(self, builder, requester=None, prev_event_ids=None):
475585 if prev_event_ids:
476586 prev_events = yield self.store.add_event_hashes(prev_event_ids)
477587 prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
508618 builder.prev_events = prev_events
509619 builder.depth = depth
510620
511 state_handler = self.state_handler
512
513 context = yield state_handler.compute_event_context(builder)
621 context = yield self.state.compute_event_context(builder)
514622 if requester:
515623 context.app_service = requester.app_service
516624
545653 event,
546654 context,
547655 ratelimit=True,
548 extra_users=[]
656 extra_users=[],
549657 ):
550 # We now need to go and hit out to wherever we need to hit out to.
551
552 if ratelimit:
553 yield self.ratelimit(requester)
658 """Processes a new event. This includes checking auth, persisting it,
659 notifying users, sending to remote servers, etc.
660
661 If called from a worker, this will hit out to the master process for final
662 processing.
663
664 Args:
665 requester (Requester)
666 event (FrozenEvent)
667 context (EventContext)
668 ratelimit (bool)
669 extra_users (list(UserID)): Any extra users to notify about event
670 """
554671
555672 try:
556673 yield self.auth.check_from_context(event, context)
560677
561678 # Ensure that we can round trip before trying to persist in db
562679 try:
563 dump = ujson.dumps(unfreeze(event.content))
564 ujson.loads(dump)
565 except:
680 dump = simplejson.dumps(unfreeze(event.content))
681 simplejson.loads(dump)
682 except Exception:
566683 logger.exception("Failed to encode content: %r", event.content)
567684 raise
568685
569 yield self.maybe_kick_guest_users(event, context)
686 yield self.action_generator.handle_push_actions_for_event(
687 event, context
688 )
689
690 try:
691 # If we're a worker we need to hit out to the master.
692 if self.config.worker_app:
693 yield send_event_to_master(
694 self.http_client,
695 host=self.config.worker_replication_host,
696 port=self.config.worker_replication_http_port,
697 requester=requester,
698 event=event,
699 context=context,
700 ratelimit=ratelimit,
701 extra_users=extra_users,
702 )
703 return
704
705 yield self.persist_and_notify_client_event(
706 requester,
707 event,
708 context,
709 ratelimit=ratelimit,
710 extra_users=extra_users,
711 )
712 except: # noqa: E722, as we reraise the exception this is fine.
713 # Ensure that we actually remove the entries in the push actions
714 # staging area, if we calculated them.
715 preserve_fn(self.store.remove_push_actions_from_staging)(event.event_id)
716 raise
717
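handle_new_client_event above routes final processing to the master when the config says we are a worker. A minimal, non-Twisted sketch of that routing decision; worker_app, forward_to_master and persist_locally are illustrative stand-ins, not Synapse's API.

def handle_new_event_sketch(event, worker_app, forward_to_master, persist_locally):
    if worker_app:
        # Workers cannot persist events themselves: ship the event (and its
        # context) to the master over the replication HTTP API and stop here.
        return forward_to_master(event)
    # On the master, persist the event and notify clients/federation directly.
    return persist_locally(event)


print(handle_new_event_sketch(
    {"type": "m.room.message"},
    worker_app="synapse.app.synchrotron",
    forward_to_master=lambda e: "sent to master",
    persist_locally=lambda e: "persisted locally",
))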
718 @defer.inlineCallbacks
719 def persist_and_notify_client_event(
720 self,
721 requester,
722 event,
723 context,
724 ratelimit=True,
725 extra_users=[],
726 ):
727 """Called when we have fully built the event, have already
728 calculated the push actions for the event, and checked auth.
729
730 This should only be run on master.
731 """
732 assert not self.config.worker_app
733
734 if ratelimit:
735 yield self.base_handler.ratelimit(requester)
736
737 yield self.base_handler.maybe_kick_guest_users(event, context)
570738
571739 if event.type == EventTypes.CanonicalAlias:
572740 # Check the alias is actually valid (at this time at least)
659827 "Changing the room create event is forbidden",
660828 )
661829
662 yield self.action_generator.handle_push_actions_for_event(
663 event, context
664 )
665
666830 (event_stream_id, max_stream_id) = yield self.store.persist_event(
667831 event, context=context
668832 )
682846 )
683847
684848 preserve_fn(_notify)()
849
850 if event.type == EventTypes.Message:
851 presence = self.hs.get_presence_handler()
852 # We don't want to block sending messages on any presence code. This
853 # matters as sometimes presence code can take a while.
854 preserve_fn(presence.bump_presence_active_time)(requester.user)
9292 self.store = hs.get_datastore()
9393 self.wheel_timer = WheelTimer()
9494 self.notifier = hs.get_notifier()
95 self.replication = hs.get_replication_layer()
9695 self.federation = hs.get_federation_sender()
9796
9897 self.state = hs.get_state_handler()
9998
100 self.replication.register_edu_handler(
99 federation_registry = hs.get_federation_registry()
100
101 federation_registry.register_edu_handler(
101102 "m.presence", self.incoming_presence
102103 )
103 self.replication.register_edu_handler(
104 federation_registry.register_edu_handler(
104105 "m.presence_invite",
105106 lambda origin, content: self.invite_presence(
106107 observed_user=UserID.from_string(content["observed_user"]),
107108 observer_user=UserID.from_string(content["observer_user"]),
108109 )
109110 )
110 self.replication.register_edu_handler(
111 federation_registry.register_edu_handler(
111112 "m.presence_accept",
112113 lambda origin, content: self.accept_presence(
113114 observed_user=UserID.from_string(content["observed_user"]),
114115 observer_user=UserID.from_string(content["observer_user"]),
115116 )
116117 )
117 self.replication.register_edu_handler(
118 federation_registry.register_edu_handler(
118119 "m.presence_deny",
119120 lambda origin, content: self.deny_presence(
120121 observed_user=UserID.from_string(content["observed_user"]),
363364 )
364365
365366 preserve_fn(self._update_states)(changes)
366 except:
367 except Exception:
367368 logger.exception("Exception in _handle_timeouts loop")
368369
369370 @defer.inlineCallbacks
11981199 )
11991200 changed = True
12001201 else:
1201 # We expect to be poked occaisonally by the other side.
1202 # We expect to be poked occasionally by the other side.
12021203 # This is to protect against forgetful/buggy servers, so that
12031204 # no one gets stuck online forever.
12041205 if now - state.last_federation_update_ts > FEDERATION_TIMEOUT:
1616
1717 from twisted.internet import defer
1818
19 import synapse.types
2019 from synapse.api.errors import SynapseError, AuthError, CodeMessageException
2120 from synapse.types import UserID, get_domain_from_id
2221 from ._base import BaseHandler
3130 def __init__(self, hs):
3231 super(ProfileHandler, self).__init__(hs)
3332
34 self.federation = hs.get_replication_layer()
35 self.federation.register_query_handler(
33 self.federation = hs.get_federation_client()
34 hs.get_federation_registry().register_query_handler(
3635 "profile", self.on_profile_query
3736 )
3837
39 self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS)
38 self.user_directory_handler = hs.get_user_directory_handler()
39
40 if hs.config.worker_app is None:
41 self.clock.looping_call(
42 self._update_remote_profile_cache, self.PROFILE_UPDATE_MS,
43 )
4044
4145 @defer.inlineCallbacks
4246 def get_profile(self, user_id):
117121 logger.exception("Failed to get displayname")
118122
119123 raise
120 except:
124 except Exception:
121125 logger.exception("Failed to get displayname")
122126 else:
123127 defer.returnValue(result["displayname"])
139143 target_user.localpart, new_displayname
140144 )
141145
142 yield self._update_join_states(requester)
146 if self.hs.config.user_directory_search_all_users:
147 profile = yield self.store.get_profileinfo(target_user.localpart)
148 yield self.user_directory_handler.handle_local_profile_change(
149 target_user.to_string(), profile
150 )
151
152 yield self._update_join_states(requester, target_user)
143153
144154 @defer.inlineCallbacks
145155 def get_avatar_url(self, target_user):
164174 if e.code != 404:
165175 logger.exception("Failed to get avatar_url")
166176 raise
167 except:
177 except Exception:
168178 logger.exception("Failed to get avatar_url")
169179
170180 defer.returnValue(result["avatar_url"])
183193 target_user.localpart, new_avatar_url
184194 )
185195
186 yield self._update_join_states(requester)
196 if self.hs.config.user_directory_search_all_users:
197 profile = yield self.store.get_profileinfo(target_user.localpart)
198 yield self.user_directory_handler.handle_local_profile_change(
199 target_user.to_string(), profile
200 )
201
202 yield self._update_join_states(requester, target_user)
187203
188204 @defer.inlineCallbacks
189205 def on_profile_query(self, args):
208224 defer.returnValue(response)
209225
210226 @defer.inlineCallbacks
211 def _update_join_states(self, requester):
212 user = requester.user
213 if not self.hs.is_mine(user):
227 def _update_join_states(self, requester, target_user):
228 if not self.hs.is_mine(target_user):
214229 return
215230
216231 yield self.ratelimit(requester)
217232
218233 room_ids = yield self.store.get_rooms_for_user(
219 user.to_string(),
234 target_user.to_string(),
220235 )
221236
222237 for room_id in room_ids:
223 handler = self.hs.get_handlers().room_member_handler
224 try:
225 # Assume the user isn't a guest because we don't let guests set
226 # profile or avatar data.
227 # XXX why are we recreating `requester` here for each room?
228 # what was wrong with the `requester` we were passed?
229 requester = synapse.types.create_requester(user)
238 handler = self.hs.get_room_member_handler()
239 try:
240 # Assume the target_user isn't a guest,
241 # because we don't let guests set profile or avatar data.
230242 yield handler.update_membership(
231243 requester,
232 user,
244 target_user,
233245 room_id,
234246 "join", # We treat a profile update like a join.
235247 ratelimit=False, # Try to hide that these events aren't atomic.
265277 },
266278 ignore_backoff=True,
267279 )
268 except:
280 except Exception:
269281 logger.exception("Failed to get avatar_url")
270282
271283 yield self.store.update_remote_profile_cache(
4040 """
4141
4242 with (yield self.read_marker_linearizer.queue((room_id, user_id))):
43 account_data = yield self.store.get_account_data_for_room(user_id, room_id)
44
45 existing_read_marker = account_data.get("m.fully_read", None)
43 existing_read_marker = yield self.store.get_account_data_for_room_and_type(
44 user_id, room_id, "m.fully_read",
45 )
4646
4747 should_update = True
4848
3434 self.store = hs.get_datastore()
3535 self.hs = hs
3636 self.federation = hs.get_federation_sender()
37 hs.get_replication_layer().register_edu_handler(
37 hs.get_federation_registry().register_edu_handler(
3838 "m.receipt", self._received_remote_receipt
3939 )
4040 self.clock = self.hs.get_clock()
1414
1515 """Contains functions for registering clients."""
1616 import logging
17 import urllib
1817
1918 from twisted.internet import defer
2019
2221 AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
2322 )
2423 from synapse.http.client import CaptchaServerHttpClient
24 from synapse import types
2525 from synapse.types import UserID
2626 from synapse.util.async import run_on_reactor
27 from synapse.util.threepids import check_3pid_allowed
2728 from ._base import BaseHandler
2829
2930 logger = logging.getLogger(__name__)
3536 super(RegistrationHandler, self).__init__(hs)
3637
3738 self.auth = hs.get_auth()
39 self._auth_handler = hs.get_auth_handler()
3840 self.profile_handler = hs.get_profile_handler()
41 self.user_directory_handler = hs.get_user_directory_handler()
3942 self.captcha_client = CaptchaServerHttpClient(hs)
4043
4144 self._next_generated_user_id = None
4548 @defer.inlineCallbacks
4649 def check_username(self, localpart, guest_access_token=None,
4750 assigned_user_id=None):
48 yield run_on_reactor()
49
50 if urllib.quote(localpart.encode('utf-8')) != localpart:
51 if types.contains_invalid_mxid_characters(localpart):
5152 raise SynapseError(
5253 400,
53 "User ID can only contain characters a-z, 0-9, or '_-./'",
54 "User ID can only contain characters a-z, 0-9, or '=_-./'",
5455 Codes.INVALID_USERNAME
5556 )
5657
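check_username now delegates to types.contains_invalid_mxid_characters. A rough restatement of that check using the character set quoted in the error message above; the regex itself is an assumption.

import re

_ALLOWED_LOCALPART = re.compile(r"^[a-z0-9=_\-./]+$")


def contains_invalid_mxid_characters_sketch(localpart):
    return _ALLOWED_LOCALPART.match(localpart) is None


print(contains_invalid_mxid_characters_sketch("alice"))   # False: every character is allowed
print(contains_invalid_mxid_characters_sketch("Alice!"))  # True: uppercase and '!' are rejected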
8081 "A different user ID has already been registered for this session",
8182 )
8283
83 yield self.check_user_id_not_appservice_exclusive(user_id)
84 self.check_user_id_not_appservice_exclusive(user_id)
8485
8586 users = yield self.store.get_users_by_id_case_insensitive(user_id)
8687 if users:
130131 yield run_on_reactor()
131132 password_hash = None
132133 if password:
133 password_hash = self.auth_handler().hash(password)
134 password_hash = yield self.auth_handler().hash(password)
134135
135136 if localpart:
136137 yield self.check_username(localpart, guest_access_token=guest_access_token)
165166 ),
166167 admin=admin,
167168 )
169
170 if self.hs.config.user_directory_search_all_users:
171 profile = yield self.store.get_profileinfo(localpart)
172 yield self.user_directory_handler.handle_local_profile_change(
173 user_id, profile
174 )
175
168176 else:
169177 # autogen a sequential user ID
170178 attempts = 0
253261 """
254262 Registers email_id as SAML2 Based Auth.
255263 """
256 if urllib.quote(localpart) != localpart:
264 if types.contains_invalid_mxid_characters(localpart):
257265 raise SynapseError(
258266 400,
259 "User ID must only contain characters which do not"
260 " require URL encoding."
267 "User ID can only contain characters a-z, 0-9, or '=_-./'",
261268 )
262269 user = UserID(localpart, self.hs.hostname)
263270 user_id = user.to_string()
286293 """
287294
288295 for c in threepidCreds:
289 logger.info("validating theeepidcred sid %s on id server %s",
296 logger.info("validating threepidcred sid %s on id server %s",
290297 c['sid'], c['idServer'])
291298 try:
292299 identity_handler = self.hs.get_handlers().identity_handler
293300 threepid = yield identity_handler.threepid_from_creds(c)
294 except:
301 except Exception:
295302 logger.exception("Couldn't validate 3pid")
296303 raise RegistrationError(400, "Couldn't validate 3pid")
297304
299306 raise RegistrationError(400, "Couldn't validate 3pid")
300307 logger.info("got threepid with medium '%s' and address '%s'",
301308 threepid['medium'], threepid['address'])
309
310 if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
311 raise RegistrationError(
312 403, "Third party identifier is not allowed"
313 )
302314
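register_email now refuses third-party identifiers rejected by check_3pid_allowed, which backs the new registration whitelist. A hedged sketch of such a check; the shape of the configuration (a list of medium/pattern entries) is an assumption made for illustration.

import re


def check_3pid_allowed_sketch(allowed_local_3pids, medium, address):
    if not allowed_local_3pids:
        return True  # no whitelist configured: everything is allowed
    return any(
        entry["medium"] == medium and re.match(entry["pattern"], address)
        for entry in allowed_local_3pids
    )


whitelist = [{"medium": "email", "pattern": r".*@example\.com$"}]
print(check_3pid_allowed_sketch(whitelist, "email", "bob@example.com"))  # True
print(check_3pid_allowed_sketch(whitelist, "msisdn", "447700900000"))    # False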
303315 @defer.inlineCallbacks
304316 def bind_emails(self, user_id, threepidCreds):
418430 create_profile_with_localpart=user.localpart,
419431 )
420432 else:
421 yield self.store.user_delete_access_tokens(user_id=user_id)
433 yield self._auth_handler.delete_access_tokens_for_user(user_id)
422434 yield self.store.add_access_token_to_user(user_id=user_id, token=token)
423435
424436 if displayname is not None:
433445 return self.hs.get_auth_handler()
434446
435447 @defer.inlineCallbacks
436 def guest_access_token_for(self, medium, address, inviter_user_id):
448 def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
449 """Get a guest access token for a 3PID, creating a guest account if
450 one doesn't already exist.
451
452 Args:
453 medium (str)
454 address (str)
455 inviter_user_id (str): The user ID who is trying to invite the
456 3PID
457
458 Returns:
459 Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
460 3PID guest account.
461 """
437462 access_token = yield self.store.get_3pid_guest_access_token(medium, address)
438463 if access_token:
439 defer.returnValue(access_token)
440
441 _, access_token = yield self.register(
464 user_info = yield self.auth.get_user_by_access_token(
465 access_token
466 )
467
468 defer.returnValue((user_info["user"].to_string(), access_token))
469
470 user_id, access_token = yield self.register(
442471 generate_token=True,
443472 make_guest=True
444473 )
445474 access_token = yield self.store.save_or_get_3pid_guest_access_token(
446475 medium, address, access_token, inviter_user_id
447476 )
448 defer.returnValue(access_token)
477
478 defer.returnValue((user_id, access_token))
00 # -*- coding: utf-8 -*-
11 # Copyright 2014 - 2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
6364 super(RoomCreationHandler, self).__init__(hs)
6465
6566 self.spam_checker = hs.get_spam_checker()
67 self.event_creation_handler = hs.get_event_creation_handler()
6668
6769 @defer.inlineCallbacks
6870 def create_room(self, requester, config, ratelimit=True):
9092 if wchar in config["room_alias_name"]:
9193 raise SynapseError(400, "Invalid characters in room alias")
9294
93 room_alias = RoomAlias.create(
95 room_alias = RoomAlias(
9496 config["room_alias_name"],
9597 self.hs.hostname,
9698 )
107109 for i in invite_list:
108110 try:
109111 UserID.from_string(i)
110 except:
112 except Exception:
111113 raise SynapseError(400, "Invalid user_id: %s" % (i,))
112114
113115 invite_3pid_list = config.get("invite_3pid", [])
122124 while attempts < 5:
123125 try:
124126 random_string = stringutils.random_string(18)
125 gen_room_id = RoomID.create(
127 gen_room_id = RoomID(
126128 random_string,
127129 self.hs.hostname,
128130 )
162164
163165 creation_content = config.get("creation_content", {})
164166
165 msg_handler = self.hs.get_handlers().message_handler
166 room_member_handler = self.hs.get_handlers().room_member_handler
167 room_member_handler = self.hs.get_room_member_handler()
167168
168169 yield self._send_events_for_new_room(
169170 requester,
170171 room_id,
171 msg_handler,
172172 room_member_handler,
173173 preset_config=preset_config,
174174 invite_list=invite_list,
180180
181181 if "name" in config:
182182 name = config["name"]
183 yield msg_handler.create_and_send_nonmember_event(
183 yield self.event_creation_handler.create_and_send_nonmember_event(
184184 requester,
185185 {
186186 "type": EventTypes.Name,
193193
194194 if "topic" in config:
195195 topic = config["topic"]
196 yield msg_handler.create_and_send_nonmember_event(
196 yield self.event_creation_handler.create_and_send_nonmember_event(
197197 requester,
198198 {
199199 "type": EventTypes.Topic,
204204 },
205205 ratelimit=False)
206206
207 content = {}
208 is_direct = config.get("is_direct", None)
209 if is_direct:
210 content["is_direct"] = is_direct
211
212207 for invitee in invite_list:
208 content = {}
209 is_direct = config.get("is_direct", None)
210 if is_direct:
211 content["is_direct"] = is_direct
212
213213 yield room_member_handler.update_membership(
214214 requester,
215215 UserID.from_string(invitee),
223223 id_server = invite_3pid["id_server"]
224224 address = invite_3pid["address"]
225225 medium = invite_3pid["medium"]
226 yield self.hs.get_handlers().room_member_handler.do_3pid_invite(
226 yield self.hs.get_room_member_handler().do_3pid_invite(
227227 room_id,
228228 requester.user,
229229 medium,
248248 self,
249249 creator, # A Requester object.
250250 room_id,
251 msg_handler,
252251 room_member_handler,
253252 preset_config,
254253 invite_list,
271270 @defer.inlineCallbacks
272271 def send(etype, content, **kwargs):
273272 event = create(etype, content, **kwargs)
274 yield msg_handler.create_and_send_nonmember_event(
273 yield self.event_creation_handler.create_and_send_nonmember_event(
275274 creator,
276275 event,
277276 ratelimit=False
475474 user.to_string()
476475 )
477476 if app_service:
478 events, end_key = yield self.store.get_appservice_room_stream(
479 service=app_service,
480 from_key=from_key,
481 to_key=to_key,
482 limit=limit,
483 )
477 # We no longer support AS users using /sync directly.
478 # See https://github.com/matrix-org/matrix-doc/issues/1144
479 raise NotImplementedError()
484480 else:
485481 room_events = yield self.store.get_membership_changes_for_user(
486482 user.to_string(), from_key, to_key
1919 from synapse.api.constants import (
2020 EventTypes, JoinRules,
2121 )
22 from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
2223 from synapse.util.async import concurrently_execute
2324 from synapse.util.caches.descriptors import cachedInlineCallbacks
2425 from synapse.util.caches.response_cache import ResponseCache
6970 if search_filter:
7071 # We explicitly don't bother caching searches or requests for
7172 # appservice specific lists.
73 logger.info("Bypassing cache as search request.")
7274 return self._get_public_room_list(
7375 limit, since_token, search_filter, network_tuple=network_tuple,
7476 )
7678 key = (limit, since_token, network_tuple)
7779 result = self.response_cache.get(key)
7880 if not result:
81 logger.info("No cached result, calculating one.")
7982 result = self.response_cache.set(
8083 key,
81 self._get_public_room_list(
84 preserve_fn(self._get_public_room_list)(
8285 limit, since_token, network_tuple=network_tuple
8386 )
8487 )
85 return result
88 else:
89 logger.info("Using cached deferred result.")
90 return make_deferred_yieldable(result)
8691
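The change above stores the in-flight Deferred in the response cache, so concurrent requests for the same (limit, since_token, network_tuple) key share a single computation while search requests bypass the cache entirely. A stripped-down sketch of that get-or-compute pattern; ResponseCacheSketch is an invented stand-in for synapse's ResponseCache.

from twisted.internet import defer


class ResponseCacheSketch(object):
    def __init__(self):
        self._pending = {}

    def get(self, key):
        return self._pending.get(key)

    def set(self, key, deferred):
        self._pending[key] = deferred
        return deferred


cache = ResponseCacheSketch()
key = (10, None, None)
result = cache.get(key)
if not result:
    # First caller: store the pending Deferred so later callers reuse it.
    result = cache.set(key, defer.succeed(["#room:example.com"]))
collected = []
result.addCallback(collected.append)  # fires immediately: defer.succeed is already resolved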
8792 @defer.inlineCallbacks
8893 def _get_public_room_list(self, limit=None, since_token=None,
148153 # We want larger rooms to be first, hence negating num_joined_users
149154 rooms_to_order_value[room_id] = (-num_joined_users, room_id)
150155
156 logger.info("Getting ordering for %i rooms since %s",
157 len(room_ids), stream_token)
151158 yield concurrently_execute(get_order_for_room, room_ids, 10)
152159
153160 sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
175182 rooms_to_scan = rooms_to_scan[:since_token.current_limit]
176183 rooms_to_scan.reverse()
177184
178 # Actually generate the entries. _append_room_entry_to_chunk will append to
179 # chunk but will stop if len(chunk) > limit
185 logger.info("After sorting and filtering, %i rooms remain",
186 len(rooms_to_scan))
187
188 # _append_room_entry_to_chunk will append to chunk but will stop if
189 # len(chunk) > limit
190 #
191 # Normally we will generate enough results on the first iteration here,
192 # but if there is a search filter, _append_room_entry_to_chunk may
193 # filter some results out, in which case we loop again.
194 #
195 # We don't want to scan over the entire range either as that
196 # would potentially waste a lot of work.
197 #
198 # XXX if there is no limit, we may end up DoSing the server with
199 # calls to get_current_state_ids for every single room on the
200 # server. Surely we should cap this somehow?
201 #
202 if limit:
203 step = limit + 1
204 else:
205 # step cannot be zero
206 step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
207
180208 chunk = []
181 if limit and not search_filter:
182 step = limit + 1
183 for i in xrange(0, len(rooms_to_scan), step):
184 # We iterate here because the vast majority of cases we'll stop
185 # at first iteration, but occaisonally _append_room_entry_to_chunk
186 # won't append to the chunk and so we need to loop again.
187 # We don't want to scan over the entire range either as that
188 # would potentially waste a lot of work.
189 yield concurrently_execute(
190 lambda r: self._append_room_entry_to_chunk(
191 r, rooms_to_num_joined[r],
192 chunk, limit, search_filter
193 ),
194 rooms_to_scan[i:i + step], 10
195 )
196 if len(chunk) >= limit + 1:
197 break
198 else:
209 for i in xrange(0, len(rooms_to_scan), step):
210 batch = rooms_to_scan[i:i + step]
211 logger.info("Processing %i rooms for result", len(batch))
199212 yield concurrently_execute(
200213 lambda r: self._append_room_entry_to_chunk(
201214 r, rooms_to_num_joined[r],
202215 chunk, limit, search_filter
203216 ),
204 rooms_to_scan, 5
205 )
217 batch, 5,
218 )
219 logger.info("Now %i rooms in result", len(chunk))
220 if len(chunk) >= limit + 1:
221 break
206222
207223 chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"]))
208224
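The comments above describe scanning rooms in batches of limit + 1 and stopping once the chunk holds one more entry than requested. A small synchronous sketch of that loop, where accept stands in for _append_room_entry_to_chunk's filtering:

def scan_in_batches(rooms, limit, accept):
    step = limit + 1 if limit else max(len(rooms), 1)  # step cannot be zero
    chunk = []
    for i in range(0, len(rooms), step):
        for room in rooms[i:i + step]:
            if accept(room):
                chunk.append(room)
        if limit and len(chunk) >= limit + 1:
            break
    return chunk


print(scan_in_batches(list(range(20)), limit=5, accept=lambda r: r % 2 == 0))
# [0, 2, 4, 6, 8, 10]: the scan stopped after the second batch of six rooms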
392408 def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
393409 search_filter=None, include_all_networks=False,
394410 third_party_instance_id=None,):
395 repl_layer = self.hs.get_replication_layer()
411 repl_layer = self.hs.get_federation_client()
396412 if search_filter:
397413 # We can't cache when asking for search
398414 return repl_layer.get_public_rooms(
00 # -*- coding: utf-8 -*-
11 # Copyright 2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1213 # See the License for the specific language governing permissions and
1314 # limitations under the License.
1415
15
16 import abc
1617 import logging
1718
1819 from signedjson.key import decode_verify_key_bytes
2829 from synapse.types import UserID, RoomID
2930 from synapse.util.async import Linearizer
3031 from synapse.util.distributor import user_left_room, user_joined_room
31 from ._base import BaseHandler
32
3233
3334 logger = logging.getLogger(__name__)
3435
3536 id_server_scheme = "https://"
3637
3738
38 class RoomMemberHandler(BaseHandler):
39 class RoomMemberHandler(object):
3940 # TODO(paul): This handler currently contains a messy conflation of
4041 # low-level API that works on UserID objects and so on, and REST-level
4142 # API that takes ID strings and returns pagination chunks. These concerns
4243 # ought to be separated out a lot better.
4344
45 __metaclass__ = abc.ABCMeta
46
4447 def __init__(self, hs):
45 super(RoomMemberHandler, self).__init__(hs)
46
48 self.hs = hs
49 self.store = hs.get_datastore()
50 self.auth = hs.get_auth()
51 self.state_handler = hs.get_state_handler()
52 self.config = hs.config
53 self.simple_http_client = hs.get_simple_http_client()
54
55 self.federation_handler = hs.get_handlers().federation_handler
56 self.directory_handler = hs.get_handlers().directory_handler
57 self.registration_handler = hs.get_handlers().registration_handler
4758 self.profile_handler = hs.get_profile_handler()
59 self.event_creation_hander = hs.get_event_creation_handler()
4860
4961 self.member_linearizer = Linearizer(name="member")
5062
5163 self.clock = hs.get_clock()
5264 self.spam_checker = hs.get_spam_checker()
5365
54 self.distributor = hs.get_distributor()
55 self.distributor.declare("user_joined_room")
56 self.distributor.declare("user_left_room")
66 @abc.abstractmethod
67 def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
68 """Try and join a room that this server is not in
69
70 Args:
71 requester (Requester)
72 remote_room_hosts (list[str]): List of servers that can be used
73 to join via.
74 room_id (str): Room that we are trying to join
75 user (UserID): User who is trying to join
76 content (dict): A dict that should be used as the content of the
77 join event.
78
79 Returns:
80 Deferred
81 """
82 raise NotImplementedError()
83
84 @abc.abstractmethod
85 def _remote_reject_invite(self, remote_room_hosts, room_id, target):
86 """Attempt to reject an invite for a room this server is not in. If we
87 fail to do so we locally mark the invite as rejected.
88
89 Args:
90 requester (Requester)
91 remote_room_hosts (list[str]): List of servers to use to try and
92 reject invite
93 room_id (str)
94 target (UserID): The user rejecting the invite
95
96 Returns:
97 Deferred[dict]: A dictionary to be returned to the client, may
98 include event_id etc, or nothing if we locally rejected
99 """
100 raise NotImplementedError()
101
102 @abc.abstractmethod
103 def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
104 """Get a guest access token for a 3PID, creating a guest account if
105 one doesn't already exist.
106
107 Args:
108 requester (Requester)
109 medium (str)
110 address (str)
111 inviter_user_id (str): The user ID who is trying to invite the
112 3PID
113
114 Returns:
115 Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
116 3PID guest account.
117 """
118 raise NotImplementedError()
119
120 @abc.abstractmethod
121 def _user_joined_room(self, target, room_id):
122 """Notifies distributor on master process that the user has joined the
123 room.
124
125 Args:
126 target (UserID)
127 room_id (str)
128
129 Returns:
130 Deferred|None
131 """
132 raise NotImplementedError()
133
134 @abc.abstractmethod
135 def _user_left_room(self, target, room_id):
136 """Notifies distributor on master process that the user has left the
137 room.
138
139 Args:
140 target (UserID)
141 room_id (str)
142
143 Returns:
144 Deferred|None
145 """
146 raise NotImplementedError()
57147
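RoomMemberHandler is now an abstract base class: the methods marked @abc.abstractmethod above must be supplied by the master and worker variants. A toy illustration of that split; the subclass name and body are invented, not Synapse classes.

import abc


class RoomMemberHandlerSketch(object):
    __metaclass__ = abc.ABCMeta  # Python 2 style, matching the hunk above

    @abc.abstractmethod
    def _remote_join(self, remote_room_hosts, room_id, user):
        raise NotImplementedError()


class MasterRoomMemberHandlerSketch(RoomMemberHandlerSketch):
    def _remote_join(self, remote_room_hosts, room_id, user):
        # The real master implementation drives the federation join dance.
        return "join %s via %s" % (room_id, remote_room_hosts[0])


handler = MasterRoomMemberHandlerSketch()
print(handler._remote_join(["hs.example.com"], "!room:example.com", "@alice:example.com"))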
58148 @defer.inlineCallbacks
59149 def _local_membership_update(
65155 ):
66156 if content is None:
67157 content = {}
68 msg_handler = self.hs.get_handlers().message_handler
69158
70159 content["membership"] = membership
71160 if requester.is_guest:
72161 content["kind"] = "guest"
73162
74 event, context = yield msg_handler.create_event(
163 event, context = yield self.event_creation_hander.create_event(
75164 requester,
76165 {
77166 "type": EventTypes.Member,
89178 )
90179
91180 # Check if this event matches the previous membership event for the user.
92 duplicate = yield msg_handler.deduplicate_state_event(event, context)
181 duplicate = yield self.event_creation_hander.deduplicate_state_event(
182 event, context,
183 )
93184 if duplicate is not None:
94185 # Discard the new event since this membership change is a no-op.
95186 defer.returnValue(duplicate)
96187
97 yield msg_handler.handle_new_client_event(
188 yield self.event_creation_hander.handle_new_client_event(
98189 requester,
99190 event,
100191 context,
116207 prev_member_event = yield self.store.get_event(prev_member_event_id)
117208 newly_joined = prev_member_event.membership != Membership.JOIN
118209 if newly_joined:
119 yield user_joined_room(self.distributor, target, room_id)
210 yield self._user_joined_room(target, room_id)
120211 elif event.membership == Membership.LEAVE:
121212 if prev_member_event_id:
122213 prev_member_event = yield self.store.get_event(prev_member_event_id)
123214 if prev_member_event.membership == Membership.JOIN:
124 user_left_room(self.distributor, target, room_id)
215 yield self._user_left_room(target, room_id)
125216
126217 defer.returnValue(event)
127
128 @defer.inlineCallbacks
129 def remote_join(self, remote_room_hosts, room_id, user, content):
130 if len(remote_room_hosts) == 0:
131 raise SynapseError(404, "No known servers")
132
133 # We don't do an auth check if we are doing an invite
134 # join dance for now, since we're kinda implicitly checking
135 # that we are allowed to join when we decide whether or not we
136 # need to do the invite/join dance.
137 yield self.hs.get_handlers().federation_handler.do_invite_join(
138 remote_room_hosts,
139 room_id,
140 user.to_string(),
141 content,
142 )
143 yield user_joined_room(self.distributor, user, room_id)
144218
145219 @defer.inlineCallbacks
146220 def update_membership(
188262 content_specified = bool(content)
189263 if content is None:
190264 content = {}
265 else:
266 # We do a copy here as we potentially change some keys
267 # later on.
268 content = dict(content)
191269
192270 effective_membership_state = action
193271 if action in ["kick", "unban"]:
196274 # if this is a join with a 3pid signature, we may need to turn a 3pid
197275 # invite into a normal invite before we can handle the join.
198276 if third_party_signed is not None:
199 replication = self.hs.get_replication_layer()
200 yield replication.exchange_third_party_invite(
277 yield self.federation_handler.exchange_third_party_invite(
201278 third_party_signed["sender"],
202279 target.to_string(),
203280 room_id,
218295 requester.user,
219296 )
220297 if not is_requester_admin:
221 if self.hs.config.block_non_admin_invites:
298 if self.config.block_non_admin_invites:
222299 logger.info(
223300 "Blocking invite: user is not admin and non-admin "
224301 "invites disabled"
277354 raise AuthError(403, "Guest access not allowed")
278355
279356 if not is_host_in_room:
280 inviter = yield self.get_inviter(target.to_string(), room_id)
357 inviter = yield self._get_inviter(target.to_string(), room_id)
281358 if inviter and not self.hs.is_mine(inviter):
282359 remote_room_hosts.append(inviter.domain)
283360
291368 if requester.is_guest:
292369 content["kind"] = "guest"
293370
294 ret = yield self.remote_join(
295 remote_room_hosts, room_id, target, content
371 ret = yield self._remote_join(
372 requester, remote_room_hosts, room_id, target, content
296373 )
297374 defer.returnValue(ret)
298375
299376 elif effective_membership_state == Membership.LEAVE:
300377 if not is_host_in_room:
301378 # perhaps we've been invited
302 inviter = yield self.get_inviter(target.to_string(), room_id)
379 inviter = yield self._get_inviter(target.to_string(), room_id)
303380 if not inviter:
304381 raise SynapseError(404, "Not a known room")
305382
313390 else:
314391 # send the rejection to the inviter's HS.
315392 remote_room_hosts = remote_room_hosts + [inviter.domain]
316 fed_handler = self.hs.get_handlers().federation_handler
317 try:
318 ret = yield fed_handler.do_remotely_reject_invite(
319 remote_room_hosts,
320 room_id,
321 target.to_string(),
322 )
323 defer.returnValue(ret)
324 except Exception as e:
325 # if we were unable to reject the exception, just mark
326 # it as rejected on our end and plough ahead.
327 #
328 # The 'except' clause is very broad, but we need to
329 # capture everything from DNS failures upwards
330 #
331 logger.warn("Failed to reject invite: %s", e)
332
333 yield self.store.locally_reject_invite(
334 target.to_string(), room_id
335 )
336
337 defer.returnValue({})
393 res = yield self._remote_reject_invite(
394 requester, remote_room_hosts, room_id, target,
395 )
396 defer.returnValue(res)
338397
339398 res = yield self._local_membership_update(
340399 requester=requester,
389448 else:
390449 requester = synapse.types.create_requester(target_user)
391450
392 message_handler = self.hs.get_handlers().message_handler
393 prev_event = yield message_handler.deduplicate_state_event(event, context)
451 prev_event = yield self.event_creation_hander.deduplicate_state_event(
452 event, context,
453 )
394454 if prev_event is not None:
395455 return
396456
407467 if is_blocked:
408468 raise SynapseError(403, "This room has been blocked on this server")
409469
410 yield message_handler.handle_new_client_event(
470 yield self.event_creation_hander.handle_new_client_event(
411471 requester,
412472 event,
413473 context,
429489 prev_member_event = yield self.store.get_event(prev_member_event_id)
430490 newly_joined = prev_member_event.membership != Membership.JOIN
431491 if newly_joined:
432 yield user_joined_room(self.distributor, target_user, room_id)
492 yield self._user_joined_room(target_user, room_id)
433493 elif event.membership == Membership.LEAVE:
434494 if prev_member_event_id:
435495 prev_member_event = yield self.store.get_event(prev_member_event_id)
436496 if prev_member_event.membership == Membership.JOIN:
437 user_left_room(self.distributor, target_user, room_id)
497 yield self._user_left_room(target_user, room_id)
438498
439499 @defer.inlineCallbacks
440500 def _can_guest_join(self, current_state_ids):
468528 Raises:
469529 SynapseError if room alias could not be found.
470530 """
471 directory_handler = self.hs.get_handlers().directory_handler
531 directory_handler = self.directory_handler
472532 mapping = yield directory_handler.get_association(room_alias)
473533
474534 if not mapping:
480540 defer.returnValue((RoomID.from_string(room_id), servers))
481541
482542 @defer.inlineCallbacks
483 def get_inviter(self, user_id, room_id):
543 def _get_inviter(self, user_id, room_id):
484544 invite = yield self.store.get_invite_for_user_in_room(
485545 user_id=user_id,
486546 room_id=room_id,
499559 requester,
500560 txn_id
501561 ):
502 if self.hs.config.block_non_admin_invites:
562 if self.config.block_non_admin_invites:
503563 is_requester_admin = yield self.auth.is_server_admin(
504564 requester.user,
505565 )
546606 str: the matrix ID of the 3pid, or None if it is not recognized.
547607 """
548608 try:
549 data = yield self.hs.get_simple_http_client().get_json(
609 data = yield self.simple_http_client.get_json(
550610 "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
551611 {
552612 "medium": medium,
557617 if "mxid" in data:
558618 if "signatures" not in data:
559619 raise AuthError(401, "No signatures on 3pid binding")
560 self.verify_any_signature(data, id_server)
620 yield self._verify_any_signature(data, id_server)
561621 defer.returnValue(data["mxid"])
562622
563623 except IOError as e:
565625 defer.returnValue(None)
566626
567627 @defer.inlineCallbacks
568 def verify_any_signature(self, data, server_hostname):
628 def _verify_any_signature(self, data, server_hostname):
569629 if server_hostname not in data["signatures"]:
570630 raise AuthError(401, "No signature from server %s" % (server_hostname,))
571631 for key_name, signature in data["signatures"][server_hostname].items():
572 key_data = yield self.hs.get_simple_http_client().get_json(
632 key_data = yield self.simple_http_client.get_json(
573633 "%s%s/_matrix/identity/api/v1/pubkey/%s" %
574634 (id_server_scheme, server_hostname, key_name,),
575635 )
594654 user,
595655 txn_id
596656 ):
597 room_state = yield self.hs.get_state_handler().get_current_state(room_id)
657 room_state = yield self.state_handler.get_current_state(room_id)
598658
599659 inviter_display_name = ""
600660 inviter_avatar_url = ""
625685
626686 token, public_keys, fallback_public_key, display_name = (
627687 yield self._ask_id_server_for_third_party_invite(
688 requester=requester,
628689 id_server=id_server,
629690 medium=medium,
630691 address=address,
639700 )
640701 )
641702
642 msg_handler = self.hs.get_handlers().message_handler
643 yield msg_handler.create_and_send_nonmember_event(
703 yield self.event_creation_hander.create_and_send_nonmember_event(
644704 requester,
645705 {
646706 "type": EventTypes.ThirdPartyInvite,
662722 @defer.inlineCallbacks
663723 def _ask_id_server_for_third_party_invite(
664724 self,
725 requester,
665726 id_server,
666727 medium,
667728 address,
678739 Asks an identity server for a third party invite.
679740
680741 Args:
742 requester (Requester)
681743 id_server (str): hostname + optional port for the identity server.
682744 medium (str): The literal string "email".
683745 address (str): The third party address being invited.
719781 "sender_avatar_url": inviter_avatar_url,
720782 }
721783
722 if self.hs.config.invite_3pid_guest:
723 registration_handler = self.hs.get_handlers().registration_handler
724 guest_access_token = yield registration_handler.guest_access_token_for(
784 if self.config.invite_3pid_guest:
785 guest_access_token, guest_user_id = yield self.get_or_register_3pid_guest(
786 requester=requester,
725787 medium=medium,
726788 address=address,
727789 inviter_user_id=inviter_user_id,
728790 )
729791
730 guest_user_info = yield self.hs.get_auth().get_user_by_access_token(
731 guest_access_token
732 )
733
734792 invite_config.update({
735793 "guest_access_token": guest_access_token,
736 "guest_user_id": guest_user_info["user"].to_string(),
794 "guest_user_id": guest_user_id,
737795 })
738796
739 data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
797 data = yield self.simple_http_client.post_urlencoded_get_json(
740798 is_url,
741799 invite_config
742800 )
759817 defer.returnValue((token, public_keys, fallback_public_key, display_name))
760818
761819 @defer.inlineCallbacks
820 def _is_host_in_room(self, current_state_ids):
821 # Have we just created the room, and is this about to be the very
822 # first member event?
823 create_event_id = current_state_ids.get(("m.room.create", ""))
824 if len(current_state_ids) == 1 and create_event_id:
825 defer.returnValue(self.hs.is_mine_id(create_event_id))
826
827 for etype, state_key in current_state_ids:
828 if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
829 continue
830
831 event_id = current_state_ids[(etype, state_key)]
832 event = yield self.store.get_event(event_id, allow_none=True)
833 if not event:
834 continue
835
836 if event.membership == Membership.JOIN:
837 defer.returnValue(True)
838
839 defer.returnValue(False)
840
841
842 class RoomMemberMasterHandler(RoomMemberHandler):
843 def __init__(self, hs):
844 super(RoomMemberMasterHandler, self).__init__(hs)
845
846 self.distributor = hs.get_distributor()
847 self.distributor.declare("user_joined_room")
848 self.distributor.declare("user_left_room")
849
850 @defer.inlineCallbacks
851 def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
852 """Implements RoomMemberHandler._remote_join
853 """
854 if len(remote_room_hosts) == 0:
855 raise SynapseError(404, "No known servers")
856
857 # We don't do an auth check if we are doing an invite
858 # join dance for now, since we're kinda implicitly checking
859 # that we are allowed to join when we decide whether or not we
860 # need to do the invite/join dance.
861 yield self.federation_handler.do_invite_join(
862 remote_room_hosts,
863 room_id,
864 user.to_string(),
865 content,
866 )
867 yield self._user_joined_room(user, room_id)
868
869 @defer.inlineCallbacks
870 def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
871 """Implements RoomMemberHandler._remote_reject_invite
872 """
873 fed_handler = self.federation_handler
874 try:
875 ret = yield fed_handler.do_remotely_reject_invite(
876 remote_room_hosts,
877 room_id,
878 target.to_string(),
879 )
880 defer.returnValue(ret)
881 except Exception as e:
882 # if we were unable to reject the invite, just mark
883 # it as rejected on our end and plough ahead.
884 #
885 # The 'except' clause is very broad, but we need to
886 # capture everything from DNS failures upwards
887 #
888 logger.warn("Failed to reject invite: %s", e)
889
890 yield self.store.locally_reject_invite(
891 target.to_string(), room_id
892 )
893 defer.returnValue({})
894
895 def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
896 """Implements RoomMemberHandler.get_or_register_3pid_guest
897 """
898 rg = self.registration_handler
899 return rg.get_or_register_3pid_guest(medium, address, inviter_user_id)
900
901 def _user_joined_room(self, target, room_id):
902 """Implements RoomMemberHandler._user_joined_room
903 """
904 return user_joined_room(self.distributor, target, room_id)
905
906 def _user_left_room(self, target, room_id):
907 """Implements RoomMemberHandler._user_left_room
908 """
909 return user_left_room(self.distributor, target, room_id)
910
911 @defer.inlineCallbacks
762912 def forget(self, user, room_id):
763913 user_id = user.to_string()
764914
778928
779929 if membership:
780930 yield self.store.forget(user_id, room_id)
781
782 @defer.inlineCallbacks
783 def _is_host_in_room(self, current_state_ids):
784 # Have we just created the room, and is this about to be the very
785 # first member event?
786 create_event_id = current_state_ids.get(("m.room.create", ""))
787 if len(current_state_ids) == 1 and create_event_id:
788 defer.returnValue(self.hs.is_mine_id(create_event_id))
789
790 for etype, state_key in current_state_ids:
791 if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
792 continue
793
794 event_id = current_state_ids[(etype, state_key)]
795 event = yield self.store.get_event(event_id, allow_none=True)
796 if not event:
797 continue
798
799 if event.membership == Membership.JOIN:
800 defer.returnValue(True)
801
802 defer.returnValue(False)
0 # -*- coding: utf-8 -*-
1 # Copyright 2018 New Vector Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import logging
16
17 from twisted.internet import defer
18
19 from synapse.api.errors import SynapseError
20 from synapse.handlers.room_member import RoomMemberHandler
21 from synapse.replication.http.membership import (
22 remote_join, remote_reject_invite, get_or_register_3pid_guest,
23 notify_user_membership_change,
24 )
25
26
27 logger = logging.getLogger(__name__)
28
29
30 class RoomMemberWorkerHandler(RoomMemberHandler):
31 @defer.inlineCallbacks
32 def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
33 """Implements RoomMemberHandler._remote_join
34 """
35 if len(remote_room_hosts) == 0:
36 raise SynapseError(404, "No known servers")
37
38 ret = yield remote_join(
39 self.simple_http_client,
40 host=self.config.worker_replication_host,
41 port=self.config.worker_replication_http_port,
42 requester=requester,
43 remote_room_hosts=remote_room_hosts,
44 room_id=room_id,
45 user_id=user.to_string(),
46 content=content,
47 )
48
49 yield self._user_joined_room(user, room_id)
50
51 defer.returnValue(ret)
52
53 def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
54 """Implements RoomMemberHandler._remote_reject_invite
55 """
56 return remote_reject_invite(
57 self.simple_http_client,
58 host=self.config.worker_replication_host,
59 port=self.config.worker_replication_http_port,
60 requester=requester,
61 remote_room_hosts=remote_room_hosts,
62 room_id=room_id,
63 user_id=target.to_string(),
64 )
65
66 def _user_joined_room(self, target, room_id):
67 """Implements RoomMemberHandler._user_joined_room
68 """
69 return notify_user_membership_change(
70 self.simple_http_client,
71 host=self.config.worker_replication_host,
72 port=self.config.worker_replication_http_port,
73 user_id=target.to_string(),
74 room_id=room_id,
75 change="joined",
76 )
77
78 def _user_left_room(self, target, room_id):
79 """Implements RoomMemberHandler._user_left_room
80 """
81 return notify_user_membership_change(
82 self.simple_http_client,
83 host=self.config.worker_replication_host,
84 port=self.config.worker_replication_http_port,
85 user_id=target.to_string(),
86 room_id=room_id,
87 change="left",
88 )
89
90 def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
91 """Implements RoomMemberHandler.get_or_register_3pid_guest
92 """
93 return get_or_register_3pid_guest(
94 self.simple_http_client,
95 host=self.config.worker_replication_host,
96 port=self.config.worker_replication_http_port,
97 requester=requester,
98 medium=medium,
99 address=address,
100 inviter_user_id=inviter_user_id,
101 )
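The worker variant above is the same set of hooks implemented as HTTP calls to the master's replication listener. A rough sketch of that delegation; replication_post is a hypothetical stand-in for the helpers imported from synapse.replication.http.membership, not the real API:

def replication_post(http_client, host, port, path, payload):
    """Hypothetical helper: POST `payload` to the master's replication HTTP
    listener and return its decoded JSON response (illustrative only)."""
    raise NotImplementedError("sketch only")


class WorkerMembershipSketch(object):
    def __init__(self, http_client, config):
        self.http_client = http_client
        self.host = config.worker_replication_host
        self.port = config.worker_replication_http_port

    def _user_joined_room(self, user_id, room_id):
        # workers never touch the distributor directly; they ask the master
        # process to emit the notification on their behalf
        return replication_post(
            self.http_client, self.host, self.port, "membership_change",
            {"user_id": user_id, "room_id": room_id, "change": "joined"},
        )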
6060 assert batch_group is not None
6161 assert batch_group_key is not None
6262 assert batch_token is not None
63 except:
63 except Exception:
6464 raise SynapseError(400, "Invalid batch")
6565
6666 try:
0 # -*- coding: utf-8 -*-
1 # Copyright 2017 New Vector Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 import logging
15
16 from twisted.internet import defer
17
18 from synapse.api.errors import Codes, StoreError, SynapseError
19 from ._base import BaseHandler
20
21 logger = logging.getLogger(__name__)
22
23
24 class SetPasswordHandler(BaseHandler):
25 """Handler which deals with changing user account passwords"""
26 def __init__(self, hs):
27 super(SetPasswordHandler, self).__init__(hs)
28 self._auth_handler = hs.get_auth_handler()
29 self._device_handler = hs.get_device_handler()
30
31 @defer.inlineCallbacks
32 def set_password(self, user_id, newpassword, requester=None):
33 password_hash = yield self._auth_handler.hash(newpassword)
34
35 except_device_id = requester.device_id if requester else None
36 except_access_token_id = requester.access_token_id if requester else None
37
38 try:
39 yield self.store.user_set_password_hash(user_id, password_hash)
40 except StoreError as e:
41 if e.code == 404:
42 raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
43 raise e
44
45 # we want to log out all of the user's other sessions. First delete
46 # all their other devices.
47 yield self._device_handler.delete_all_devices_for_user(
48 user_id, except_device_id=except_device_id,
49 )
50
51 # and now delete any access tokens which weren't associated with
52 # devices (or were associated with this device).
53 yield self._auth_handler.delete_access_tokens_for_user(
54 user_id, except_token_id=except_access_token_id,
55 )
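For orientation, a hedged sketch of how a caller might invoke this handler. The get_set_password_handler accessor is an assumption about the HomeServer object rather than something shown in this hunk:

from twisted.internet import defer


@defer.inlineCallbacks
def change_password(hs, user_id, new_password, requester=None):
    # assumed accessor; adjust to however the handler is actually obtained
    handler = hs.get_set_password_handler()
    yield handler.set_password(user_id, new_password, requester=requester)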
1414
1515 from synapse.api.constants import Membership, EventTypes
1616 from synapse.util.async import concurrently_execute
17 from synapse.util.logcontext import LoggingContext
17 from synapse.util.logcontext import LoggingContext, make_deferred_yieldable, preserve_fn
1818 from synapse.util.metrics import Measure, measure_func
1919 from synapse.util.caches.response_cache import ResponseCache
2020 from synapse.push.clientformat import format_push_rules_for_user
183183 if not result:
184184 result = self.response_cache.set(
185185 sync_config.request_key,
186 self._wait_for_sync_for_user(
186 preserve_fn(self._wait_for_sync_for_user)(
187187 sync_config, since_token, timeout, full_state
188188 )
189189 )
190 return result
190 return make_deferred_yieldable(result)
191191
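The preserve_fn / make_deferred_yieldable pair is the logcontext idiom for a deferred shared between requests: start the work following the rules for background tasks, then re-attach the caller's logcontext when yielding on the (possibly cached) result. A stripped-down sketch of the idiom; the dict below is a naive stand-in for ResponseCache, not its real API:

from synapse.util.logcontext import make_deferred_yieldable, preserve_fn

_pending = {}  # naive stand-in for ResponseCache


def get_or_compute(key, compute):
    """Share one in-flight computation between concurrent callers."""
    d = _pending.get(key)
    if d is None:
        # start the work following the logcontext rules for background tasks
        d = preserve_fn(compute)(key)
        _pending[key] = d
    # re-attach the current caller's logcontext while waiting on the result
    return make_deferred_yieldable(d)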
192192 @defer.inlineCallbacks
193193 def _wait_for_sync_for_user(self, sync_config, since_token, timeout,
234234 defer.returnValue(rules)
235235
236236 @defer.inlineCallbacks
237 def ephemeral_by_room(self, sync_config, now_token, since_token=None):
237 def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
238238 """Get the ephemeral events for each room the user is in
239239 Args:
240 sync_config (SyncConfig): The flags, filters and user for the sync.
240 sync_result_builder (SyncResultBuilder)
241241 now_token (StreamToken): Where the server is currently up to.
242242 since_token (StreamToken): Where the server was when the client
243243 last synced.
247247 typing events for that room.
248248 """
249249
250 sync_config = sync_result_builder.sync_config
251
250252 with Measure(self.clock, "ephemeral_by_room"):
251253 typing_key = since_token.typing_key if since_token else "0"
252254
253 room_ids = yield self.store.get_rooms_for_user(sync_config.user.to_string())
255 room_ids = sync_result_builder.joined_room_ids
254256
255257 typing_source = self.event_sources.sources["typing"]
256258 typing, typing_key = yield typing_source.get_new_events(
564566 # Always use the `now_token` in `SyncResultBuilder`
565567 now_token = yield self.event_sources.get_current_token()
566568
569 user_id = sync_config.user.to_string()
570 app_service = self.store.get_app_service_by_user_id(user_id)
571 if app_service:
572 # We no longer support AS users using /sync directly.
573 # See https://github.com/matrix-org/matrix-doc/issues/1144
574 raise NotImplementedError()
575 else:
576 joined_room_ids = yield self.get_rooms_for_user_at(
577 user_id, now_token.room_stream_id,
578 )
579
567580 sync_result_builder = SyncResultBuilder(
568581 sync_config, full_state,
569582 since_token=since_token,
570583 now_token=now_token,
584 joined_room_ids=joined_room_ids,
571585 )
572586
573587 account_data_by_room = yield self._generate_sync_entry_for_account_data(
602616 device_id = sync_config.device_id
603617 one_time_key_counts = {}
604618 if device_id:
605 user_id = sync_config.user.to_string()
606619 one_time_key_counts = yield self.store.count_e2e_one_time_keys(
607620 user_id, device_id
608621 )
890903 ephemeral_by_room = {}
891904 else:
892905 now_token, ephemeral_by_room = yield self.ephemeral_by_room(
893 sync_result_builder.sync_config,
906 sync_result_builder,
894907 now_token=sync_result_builder.now_token,
895908 since_token=sync_result_builder.since_token,
896909 )
9951008 if rooms_changed:
9961009 defer.returnValue(True)
9971010
998 app_service = self.store.get_app_service_by_user_id(user_id)
999 if app_service:
1000 rooms = yield self.store.get_app_service_rooms(app_service)
1001 joined_room_ids = set(r.room_id for r in rooms)
1002 else:
1003 joined_room_ids = yield self.store.get_rooms_for_user(user_id)
1004
10051011 stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
1006 for room_id in joined_room_ids:
1012 for room_id in sync_result_builder.joined_room_ids:
10071013 if self.store.has_room_changed_since(room_id, stream_id):
10081014 defer.returnValue(True)
10091015 defer.returnValue(False)
10261032 sync_config = sync_result_builder.sync_config
10271033
10281034 assert since_token
1029
1030 app_service = self.store.get_app_service_by_user_id(user_id)
1031 if app_service:
1032 rooms = yield self.store.get_app_service_rooms(app_service)
1033 joined_room_ids = set(r.room_id for r in rooms)
1034 else:
1035 joined_room_ids = yield self.store.get_rooms_for_user(user_id)
10361035
10371036 # Get a list of membership change events that have happened.
10381037 rooms_changed = yield self.store.get_membership_changes_for_user(
10561055 # we do send down the room, and with full state, where necessary
10571056
10581057 old_state_ids = None
1059 if room_id in joined_room_ids and non_joins:
1058 if room_id in sync_result_builder.joined_room_ids and non_joins:
10601059 # Always include if the user (re)joined the room, especially
10611060 # important so that device list changes are calculated correctly.
10621061 # If there are non join member events, but we are still in the room,
10661065 # User is in the room so we don't need to do the invite/leave checks
10671066 continue
10681067
1069 if room_id in joined_room_ids or has_join:
1068 if room_id in sync_result_builder.joined_room_ids or has_join:
10701069 old_state_ids = yield self.get_state_at(room_id, since_token)
10711070 old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
10721071 old_mem_ev = None
10781077 newly_joined_rooms.append(room_id)
10791078
10801079 # If user is in the room then we don't need to do the invite/leave checks
1081 if room_id in joined_room_ids:
1080 if room_id in sync_result_builder.joined_room_ids:
10821081 continue
10831082
10841083 if not non_joins:
11451144
11461145 # Get all events for rooms we're currently joined to.
11471146 room_to_events = yield self.store.get_room_events_stream_for_rooms(
1148 room_ids=joined_room_ids,
1147 room_ids=sync_result_builder.joined_room_ids,
11491148 from_key=since_token.room_key,
11501149 to_key=now_token.room_key,
11511150 limit=timeline_limit + 1,
11531152
11541153 # We loop through all room ids, even if there are no new events, in case
11551154 # there are non-room events that we need to notify about.
1156 for room_id in joined_room_ids:
1155 for room_id in sync_result_builder.joined_room_ids:
11571156 room_entry = room_to_events.get(room_id, None)
11581157
11591158 if room_entry:
13611360 else:
13621361 raise Exception("Unrecognized rtype: %r", room_builder.rtype)
13631362
1363 @defer.inlineCallbacks
1364 def get_rooms_for_user_at(self, user_id, stream_ordering):
1365 """Get set of joined rooms for a user at the given stream ordering.
1366
1367 The stream ordering *must* be recent, otherwise this may throw an
1368 exception if older than a month. (This function is called with the
1369 current token, which should be perfectly fine).
1370
1371 Args:
1372 user_id (str)
1373 stream_ordering (int)
1374
1375 Returns:
1376 Deferred[frozenset[str]]: Set of room_ids the user is in at given
1377 stream_ordering.
1378 """
1379 joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(
1380 user_id,
1381 )
1382
1383 joined_room_ids = set()
1384
1385 # We need to check that the stream ordering of the join for each room
1386 # is before the stream_ordering asked for. This might not be the case
1387 # if the user joins a room between us getting the current token and
1388 # calling `get_rooms_for_user_with_stream_ordering`.
1389 # If the membership's stream ordering is after the given stream
1390 # ordering, we need to go and work out if the user was in the room
1391 # before.
1392 for room_id, membership_stream_ordering in joined_rooms:
1393 if membership_stream_ordering <= stream_ordering:
1394 joined_room_ids.add(room_id)
1395 continue
1396
1397 logger.info("User joined room after current token: %s", room_id)
1398
1399 extrems = yield self.store.get_forward_extremeties_for_room(
1400 room_id, stream_ordering,
1401 )
1402 users_in_room = yield self.state.get_current_user_in_room(
1403 room_id, extrems,
1404 )
1405 if user_id in users_in_room:
1406 joined_room_ids.add(room_id)
1407
1408 joined_room_ids = frozenset(joined_room_ids)
1409 defer.returnValue(joined_room_ids)
1410
13641411
13651412 def _action_has_highlight(actions):
13661413 for action in actions:
14101457
14111458 class SyncResultBuilder(object):
14121459 "Used to help build up a new SyncResult for a user"
1413 def __init__(self, sync_config, full_state, since_token, now_token):
1460 def __init__(self, sync_config, full_state, since_token, now_token,
1461 joined_room_ids):
14141462 """
14151463 Args:
14161464 sync_config(SyncConfig)
14221470 self.full_state = full_state
14231471 self.since_token = since_token
14241472 self.now_token = now_token
1473 self.joined_room_ids = joined_room_ids
14251474
14261475 self.presence = []
14271476 self.account_data = []
5555
5656 self.federation = hs.get_federation_sender()
5757
58 hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu)
58 hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
5959
6060 hs.get_distributor().observe("user_left_room", self.user_left_room)
6161
1919 from synapse.storage.roommember import ProfileInfo
2020 from synapse.util.metrics import Measure
2121 from synapse.util.async import sleep
22 from synapse.types import get_localpart_from_id
2223
2324
2425 logger = logging.getLogger(__name__)
2526
2627
27 class UserDirectoyHandler(object):
28 class UserDirectoryHandler(object):
2829 """Handles querying of and keeping updated the user_directory.
2930
3031 N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY
4041 one public room.
4142 """
4243
43 INITIAL_SLEEP_MS = 50
44 INITIAL_SLEEP_COUNT = 100
45 INITIAL_BATCH_SIZE = 100
44 INITIAL_ROOM_SLEEP_MS = 50
45 INITIAL_ROOM_SLEEP_COUNT = 100
46 INITIAL_ROOM_BATCH_SIZE = 100
47 INITIAL_USER_SLEEP_MS = 10
4648
4749 def __init__(self, hs):
4850 self.store = hs.get_datastore()
5254 self.notifier = hs.get_notifier()
5355 self.is_mine_id = hs.is_mine_id
5456 self.update_user_directory = hs.config.update_user_directory
57 self.search_all_users = hs.config.user_directory_search_all_users
5558
5659 # When start up for the first time we need to populate the user_directory.
5760 # This is a set of user_id's we've inserted already
110113 self._is_processing = False
111114
112115 @defer.inlineCallbacks
116 def handle_local_profile_change(self, user_id, profile):
117 """Called to update index of our local user profiles when they change
118 irrespective of any rooms the user may be in.
119 """
120 yield self.store.update_profile_in_user_dir(
121 user_id, profile.display_name, profile.avatar_url, None,
122 )
123
124 @defer.inlineCallbacks
113125 def _unsafe_process(self):
114126 # If self.pos is None then means we haven't fetched it from DB
115127 if self.pos is None:
147159 room_ids = yield self.store.get_all_rooms()
148160
149161 logger.info("Doing initial update of user directory. %d rooms", len(room_ids))
150 num_processed_rooms = 1
162 num_processed_rooms = 0
151163
152164 for room_id in room_ids:
153 logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids))
154 yield self._handle_intial_room(room_id)
165 logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids))
166 yield self._handle_initial_room(room_id)
155167 num_processed_rooms += 1
156 yield sleep(self.INITIAL_SLEEP_MS / 1000.)
168 yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
157169
158170 logger.info("Processed all rooms.")
171
172 if self.search_all_users:
173 num_processed_users = 0
174 user_ids = yield self.store.get_all_local_users()
175 logger.info("Doing initial update of user directory. %d users", len(user_ids))
176 for user_id in user_ids:
177 # We add profiles for all users even if they don't match the
178 # include pattern, just in case we want to change it in future
179 logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids))
180 yield self._handle_local_user(user_id)
181 num_processed_users += 1
182 yield sleep(self.INITIAL_USER_SLEEP_MS / 1000.)
183
184 logger.info("Processed all users")
159185
160186 self.initially_handled_users = None
161187 self.initially_handled_users_in_public = None
165191 yield self.store.update_user_directory_stream_pos(new_pos)
166192
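The INITIAL_*_SLEEP_* constants implement a simple pacing pattern for the one-off backfill: process a chunk, then sleep briefly so the reactor can service other work. A condensed sketch of that loop; the constants and function names are illustrative:

from twisted.internet import defer

from synapse.util.async import sleep

SLEEP_MS = 50      # pause length, as with INITIAL_ROOM_SLEEP_MS
SLEEP_EVERY = 100  # items to process between pauses


@defer.inlineCallbacks
def process_backlog(items, handle_one):
    """Walk a large backlog without monopolising the reactor."""
    for count, item in enumerate(items):
        if count and count % SLEEP_EVERY == 0:
            yield sleep(SLEEP_MS / 1000.)
        yield handle_one(item)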
167193 @defer.inlineCallbacks
168 def _handle_intial_room(self, room_id):
194 def _handle_initial_room(self, room_id):
169195 """Called when we initially fill out user_directory one room at a time
170196 """
171197 is_in_room = yield self.store.is_host_joined(room_id, self.server_name)
200226 to_update = set()
201227 count = 0
202228 for user_id in user_ids:
203 if count % self.INITIAL_SLEEP_COUNT == 0:
204 yield sleep(self.INITIAL_SLEEP_MS / 1000.)
229 if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
230 yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
205231
206232 if not self.is_mine_id(user_id):
207233 count += 1
215241 if user_id == other_user_id:
216242 continue
217243
218 if count % self.INITIAL_SLEEP_COUNT == 0:
219 yield sleep(self.INITIAL_SLEEP_MS / 1000.)
244 if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
245 yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
220246 count += 1
221247
222248 user_set = (user_id, other_user_id)
236262 else:
237263 self.initially_handled_users_share_private_room.add(user_set)
238264
239 if len(to_insert) > self.INITIAL_BATCH_SIZE:
265 if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE:
240266 yield self.store.add_users_who_share_room(
241267 room_id, not is_public, to_insert,
242268 )
243269 to_insert.clear()
244270
245 if len(to_update) > self.INITIAL_BATCH_SIZE:
271 if len(to_update) > self.INITIAL_ROOM_BATCH_SIZE:
246272 yield self.store.update_users_who_share_room(
247273 room_id, not is_public, to_update,
248274 )
384410 yield self._handle_remove_user(room_id, user_id)
385411
386412 @defer.inlineCallbacks
413 def _handle_local_user(self, user_id):
414 """Adds a new local roomless user into the user_directory_search table.
415 Used to populate the user index when we have
416 user_directory_search_all_users specified.
417 """
418 logger.debug("Adding new local user to dir, %r", user_id)
419
420 profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id))
421
422 row = yield self.store.get_user_in_directory(user_id)
423 if not row:
424 yield self.store.add_profiles_to_user_dir(None, {user_id: profile})
425
426 @defer.inlineCallbacks
387427 def _handle_new_user(self, room_id, user_id, profile):
388428 """Called when we might need to add user to directory
389429
390430 Args:
391 room_id (str): room_id that user joined or started being public that
431 room_id (str): room_id that user joined or started being public
392432 user_id (str)
393433 """
394 logger.debug("Adding user to dir, %r", user_id)
434 logger.debug("Adding new user to dir, %r", user_id)
395435
396436 row = yield self.store.get_user_in_directory(user_id)
397437 if not row:
406446 if not row:
407447 yield self.store.add_users_to_public_room(room_id, [user_id])
408448 else:
409 logger.debug("Not adding user to public dir, %r", user_id)
449 logger.debug("Not adding new user to public dir, %r", user_id)
410450
411451 # Now we update users who share rooms with users. We do this by getting
412452 # all the current users in the room and seeing which aren't already
0 # -*- coding: utf-8 -*-
1 # Copyright 2017 New Vector Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from synapse.http.server import wrap_request_handler
16 from twisted.web.resource import Resource
17 from twisted.web.server import NOT_DONE_YET
18
19
20 class AdditionalResource(Resource):
21 """Resource wrapper for additional_resources
22
23 If the user has configured additional_resources, we need to wrap the
24 handler class with a Resource so that we can map it into the resource tree.
25
26 This class is also where we wrap the request handler with logging, metrics,
27 and exception handling.
28 """
29 def __init__(self, hs, handler):
30 """Initialise AdditionalResource
31
32 The ``handler`` should return a deferred which completes when it has
33 done handling the request. It should write a response with
34 ``request.write()``, and call ``request.finish()``.
35
36 Args:
37 hs (synapse.server.HomeServer): homeserver
38 handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
39 function to be called to handle the request.
40 """
41 Resource.__init__(self)
42 self._handler = handler
43
44 # these are required by the request_handler wrapper
45 self.version_string = hs.version_string
46 self.clock = hs.get_clock()
47
48 def render(self, request):
49 self._async_render(request)
50 return NOT_DONE_YET
51
52 @wrap_request_handler
53 def _async_render(self, request):
54 return self._handler(request)
1717 from synapse.api.errors import (
1818 CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
1919 )
20 from synapse.util.logcontext import preserve_context_over_fn
20 from synapse.util.caches import CACHE_SIZE_FACTOR
21 from synapse.util.logcontext import make_deferred_yieldable
2122 from synapse.util import logcontext
2223 import synapse.metrics
2324 from synapse.http.endpoint import SpiderEndpoint
2930 from twisted.web.client import (
3031 BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
3132 readBody, PartialDownloadError,
33 HTTPConnectionPool,
3234 )
3335 from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
3436 from twisted.web.http import PotentialDataLoss
6365 """
6466 def __init__(self, hs):
6567 self.hs = hs
68
69 pool = HTTPConnectionPool(reactor)
70
71 # the pusher makes lots of concurrent SSL connections to sygnal, and
72 # tends to do so in batches, so we need to allow the pool to keep lots
73 # of idle connections around.
74 pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
75 pool.cachedConnectionTimeout = 2 * 60
76
6677 # The default context factory in Twisted 14.0.0 (which we require) is
6778 # BrowserLikePolicyForHTTPS which will do regular cert validation
6879 # 'like a browser'
6980 self.agent = Agent(
7081 reactor,
7182 connectTimeout=15,
72 contextFactory=hs.get_http_client_context_factory()
83 contextFactory=hs.get_http_client_context_factory(),
84 pool=pool,
7385 )
7486 self.user_agent = hs.version_string
7587 self.clock = hs.get_clock()
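The pooling here is stock Twisted: an HTTPConnectionPool shared by an Agent keeps persistent connections warm between requests. A standalone sketch of the same wiring, with example values:

from twisted.internet import reactor
from twisted.web.client import Agent, HTTPConnectionPool

pool = HTTPConnectionPool(reactor)
pool.maxPersistentPerHost = 10       # idle connections kept per host
pool.cachedConnectionTimeout = 120   # seconds before idle connections are dropped

agent = Agent(reactor, connectTimeout=15, pool=pool)
d = agent.request(b"GET", b"https://example.com/")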
113125 raise e
114126
115127 @defer.inlineCallbacks
116 def post_urlencoded_get_json(self, uri, args={}):
128 def post_urlencoded_get_json(self, uri, args={}, headers=None):
129 """
130 Args:
131 uri (str):
132 args (dict[str, str|List[str]]): query params
133 headers (dict[str, List[str]]|None): If not None, a map from
134 header name to a list of values for that header
135
136 Returns:
137 Deferred[object]: parsed json
138 """
139
117140 # TODO: Do we ever want to log message contents?
118141 logger.debug("post_urlencoded_get_json args: %s", args)
119142
120143 query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
144
145 actual_headers = {
146 b"Content-Type": [b"application/x-www-form-urlencoded"],
147 b"User-Agent": [self.user_agent],
148 }
149 if headers:
150 actual_headers.update(headers)
121151
122152 response = yield self.request(
123153 "POST",
124154 uri.encode("ascii"),
125 headers=Headers({
126 b"Content-Type": [b"application/x-www-form-urlencoded"],
127 b"User-Agent": [self.user_agent],
128 }),
155 headers=Headers(actual_headers),
129156 bodyProducer=FileBodyProducer(StringIO(query_bytes))
130157 )
131158
132 body = yield preserve_context_over_fn(readBody, response)
159 body = yield make_deferred_yieldable(readBody(response))
133160
134161 defer.returnValue(json.loads(body))
135162
136163 @defer.inlineCallbacks
137 def post_json_get_json(self, uri, post_json):
164 def post_json_get_json(self, uri, post_json, headers=None):
165 """
166
167 Args:
168 uri (str):
169 post_json (object):
170 headers (dict[str, List[str]]|None): If not None, a map from
171 header name to a list of values for that header
172
173 Returns:
174 Deferred[object]: parsed json
175 """
138176 json_str = encode_canonical_json(post_json)
139177
140178 logger.debug("HTTP POST %s -> %s", json_str, uri)
179
180 actual_headers = {
181 b"Content-Type": [b"application/json"],
182 b"User-Agent": [self.user_agent],
183 }
184 if headers:
185 actual_headers.update(headers)
141186
142187 response = yield self.request(
143188 "POST",
144189 uri.encode("ascii"),
145 headers=Headers({
146 b"Content-Type": [b"application/json"],
147 b"User-Agent": [self.user_agent],
148 }),
190 headers=Headers(actual_headers),
149191 bodyProducer=FileBodyProducer(StringIO(json_str))
150192 )
151193
152 body = yield preserve_context_over_fn(readBody, response)
194 body = yield make_deferred_yieldable(readBody(response))
153195
154196 if 200 <= response.code < 300:
155197 defer.returnValue(json.loads(body))
159201 defer.returnValue(json.loads(body))
160202
161203 @defer.inlineCallbacks
162 def get_json(self, uri, args={}):
204 def get_json(self, uri, args={}, headers=None):
163205 """ Gets some json from the given URI.
164206
165207 Args:
168210 None.
169211 **Note**: The value of each key is assumed to be an iterable
170212 and *not* a string.
213 headers (dict[str, List[str]]|None): If not None, a map from
214 header name to a list of values for that header
171215 Returns:
172216 Deferred: Succeeds when we get *any* 2xx HTTP response, with the
173217 HTTP body as JSON.
176220 error message.
177221 """
178222 try:
179 body = yield self.get_raw(uri, args)
223 body = yield self.get_raw(uri, args, headers=headers)
180224 defer.returnValue(json.loads(body))
181225 except CodeMessageException as e:
182226 raise self._exceptionFromFailedRequest(e.code, e.msg)
183227
184228 @defer.inlineCallbacks
185 def put_json(self, uri, json_body, args={}):
229 def put_json(self, uri, json_body, args={}, headers=None):
186230 """ Puts some json to the given URI.
187231
188232 Args:
192236 None.
193237 **Note**: The value of each key is assumed to be an iterable
194238 and *not* a string.
239 headers (dict[str, List[str]]|None): If not None, a map from
240 header name to a list of values for that header
195241 Returns:
196242 Deferred: Succeeds when we get *any* 2xx HTTP response, with the
197243 HTTP body as JSON.
204250
205251 json_str = encode_canonical_json(json_body)
206252
253 actual_headers = {
254 b"Content-Type": [b"application/json"],
255 b"User-Agent": [self.user_agent],
256 }
257 if headers:
258 actual_headers.update(headers)
259
207260 response = yield self.request(
208261 "PUT",
209262 uri.encode("ascii"),
210 headers=Headers({
211 b"User-Agent": [self.user_agent],
212 "Content-Type": ["application/json"]
213 }),
263 headers=Headers(actual_headers),
214264 bodyProducer=FileBodyProducer(StringIO(json_str))
215265 )
216266
217 body = yield preserve_context_over_fn(readBody, response)
267 body = yield make_deferred_yieldable(readBody(response))
218268
219269 if 200 <= response.code < 300:
220270 defer.returnValue(json.loads(body))
225275 raise CodeMessageException(response.code, body)
226276
227277 @defer.inlineCallbacks
228 def get_raw(self, uri, args={}):
278 def get_raw(self, uri, args={}, headers=None):
229279 """ Gets raw text from the given URI.
230280
231281 Args:
234284 None.
235285 **Note**: The value of each key is assumed to be an iterable
236286 and *not* a string.
287 headers (dict[str, List[str]]|None): If not None, a map from
288 header name to a list of values for that header
237289 Returns:
238290 Deferred: Succeeds when we get *any* 2xx HTTP response, with the
239291 HTTP body at text.
245297 query_bytes = urllib.urlencode(args, True)
246298 uri = "%s?%s" % (uri, query_bytes)
247299
300 actual_headers = {
301 b"User-Agent": [self.user_agent],
302 }
303 if headers:
304 actual_headers.update(headers)
305
248306 response = yield self.request(
249307 "GET",
250308 uri.encode("ascii"),
251 headers=Headers({
252 b"User-Agent": [self.user_agent],
253 })
254 )
255
256 body = yield preserve_context_over_fn(readBody, response)
309 headers=Headers(actual_headers),
310 )
311
312 body = yield make_deferred_yieldable(readBody(response))
257313
258314 if 200 <= response.code < 300:
259315 defer.returnValue(body)
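Each of these methods now merges an optional caller-supplied header map over the defaults. A hedged usage sketch; the endpoint and header value are placeholders:

from twisted.internet import defer


@defer.inlineCallbacks
def lookup_with_auth(client):
    # `client` is a SimpleHttpClient; URL and header value are placeholders
    data = yield client.get_json(
        "https://id.example.com/_matrix/identity/api/v1/lookup",
        {"medium": "email", "address": "alice@example.com"},
        headers={b"Authorization": [b"Bearer placeholder"]},
    )
    defer.returnValue(data)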
273329 # The two should be factored out.
274330
275331 @defer.inlineCallbacks
276 def get_file(self, url, output_stream, max_size=None):
332 def get_file(self, url, output_stream, max_size=None, headers=None):
277333 """GETs a file from a given URL
278334 Args:
279335 url (str): The URL to GET
280336 output_stream (file): File to write the response body to.
337 headers (dict[str, List[str]]|None): If not None, a map from
338 header name to a list of values for that header
281339 Returns:
282340 A (int,dict,string,int) tuple of the file length, dict of the response
283341 headers, absolute URI of the response and HTTP response code.
284342 """
285343
344 actual_headers = {
345 b"User-Agent": [self.user_agent],
346 }
347 if headers:
348 actual_headers.update(headers)
349
286350 response = yield self.request(
287351 "GET",
288352 url.encode("ascii"),
289 headers=Headers({
290 b"User-Agent": [self.user_agent],
291 })
292 )
293
294 headers = dict(response.headers.getAllRawHeaders())
295
296 if 'Content-Length' in headers and headers['Content-Length'] > max_size:
353 headers=Headers(actual_headers),
354 )
355
356 resp_headers = dict(response.headers.getAllRawHeaders())
357
358 if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size:
297359 logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
298360 raise SynapseError(
299361 502,
314376 # straight back in again
315377
316378 try:
317 length = yield preserve_context_over_fn(
318 _readBodyToFile,
319 response, output_stream, max_size
320 )
379 length = yield make_deferred_yieldable(_readBodyToFile(
380 response, output_stream, max_size,
381 ))
321382 except Exception as e:
322383 logger.exception("Failed to download body")
323384 raise SynapseError(
326387 Codes.UNKNOWN,
327388 )
328389
329 defer.returnValue((length, headers, response.request.absoluteURI, response.code))
390 defer.returnValue(
391 (length, resp_headers, response.request.absoluteURI, response.code),
392 )
330393
331394
332395 # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
394457 )
395458
396459 try:
397 body = yield preserve_context_over_fn(readBody, response)
460 body = yield make_deferred_yieldable(readBody(response))
398461 defer.returnValue(body)
399462 except PartialDownloadError as e:
400463 # twisted dislikes google's response, no content length.
356356 def eb(res, record_type):
357357 if res.check(DNSNameError):
358358 return []
359 logger.warn("Error looking up %s for %s: %s",
360 record_type, host, res, res.value)
359 logger.warn("Error looking up %s for %s: %s", record_type, host, res)
361360 return res
362361
363362 # no logcontexts here, so we can safely fire these off and gatherResults
364 d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb)
365 d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb)
363 d1 = dns_client.lookupAddress(host).addCallbacks(
364 cb, eb, errbackArgs=("A", ))
365 d2 = dns_client.lookupIPV6Address(host).addCallbacks(
366 cb, eb, errbackArgs=("AAAA", ))
366367 results = yield defer.DeferredList(
367368 [d1, d2], consumeErrors=True)
368369
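errbackArgs is standard Deferred.addCallbacks behaviour: the extra tuple is passed to the errback after the failure, which is how the record type reaches eb() without a closure. A tiny standalone illustration:

from twisted.internet import defer


def on_result(value, record_type):
    print("%s lookup succeeded: %r" % (record_type, value))


def on_error(failure, record_type):
    print("%s lookup failed: %s" % (record_type, failure.getErrorMessage()))
    return failure


d = defer.succeed(["1.2.3.4"])
d.addCallbacks(on_result, on_error,
               callbackArgs=("A",), errbackArgs=("A",))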
2626 from canonicaljson import encode_canonical_json
2727
2828 from synapse.api.errors import (
29 SynapseError, Codes, HttpResponseException,
29 SynapseError, Codes, HttpResponseException, FederationDeniedError,
3030 )
3131
3232 from signedjson.sign import sign_json
122122
123123 Fails with ``HTTPRequestException``: if we get an HTTP response
124124 code >= 300.
125
125126 Fails with ``NotRetryingDestination`` if we are not yet ready
126127 to retry this server.
128
129 Fails with ``FederationDeniedError`` if this destination
130 is not on our federation whitelist
131
127132 (May also fail with plenty of other Exceptions for things like DNS
128133 failures, connection failures, SSL failures.)
129134 """
135 if (
136 self.hs.config.federation_domain_whitelist and
137 destination not in self.hs.config.federation_domain_whitelist
138 ):
139 raise FederationDeniedError(destination)
140
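Raising a dedicated error lets callers distinguish a whitelist rejection from ordinary transport failures. A hedged sketch of a call site doing so; the path argument is a placeholder:

from twisted.internet import defer

from synapse.api.errors import FederationDeniedError


@defer.inlineCallbacks
def fetch_if_allowed(client, destination, path):
    # `client` is a MatrixFederationHttpClient; `path` is a placeholder
    try:
        result = yield client.get_json(destination, path)
    except FederationDeniedError:
        # destination is not on federation_domain_whitelist; no point retrying
        result = None
    defer.returnValue(result)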
130141 limiter = yield synapse.util.retryutils.get_retry_limiter(
131142 destination,
132143 self.clock,
307318
308319 Fails with ``NotRetryingDestination`` if we are not yet ready
309320 to retry this server.
321
322 Fails with ``FederationDeniedError`` if this destination
323 is not on our federation whitelist
310324 """
311325
312326 if not json_data_callback:
367381
368382 Fails with ``NotRetryingDestination`` if we are not yet ready
369383 to retry this server.
384
385 Fails with ``FederationDeniedError`` if this destination
386 is not on our federation whitelist
370387 """
371388
372389 def body_callback(method, url_bytes, headers_dict):
421438
422439 Fails with ``NotRetryingDestination`` if we are not yet ready
423440 to retry this server.
441
442 Fails with ``FederationDeniedError`` if this destination
443 is not on our federation whitelist
424444 """
425445 logger.debug("get_json args: %s", args)
426446
474494
475495 Fails with ``NotRetryingDestination`` if we are not yet ready
476496 to retry this server.
497
498 Fails with ``FederationDeniedError`` if this destination
499 is not on our federation whitelist
477500 """
478501
479502 response = yield self._request(
517540
518541 Fails with ``NotRetryingDestination`` if we are not yet ready
519542 to retry this server.
543
544 Fails with ``FederationDeniedError`` if this destination
545 is not on our federation whitelist
520546 """
521547
522548 encoded_args = {}
549575 length = yield _readBodyToFile(
550576 response, output_stream, max_size
551577 )
552 except:
578 except Exception:
553579 logger.exception("Failed to download body")
554580 raise
555581
00 # -*- coding: utf-8 -*-
11 # Copyright 2014-2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
2728 )
2829
2930 from twisted.internet import defer
31 from twisted.python import failure
3032 from twisted.web import server, resource
3133 from twisted.web.server import NOT_DONE_YET
3234 from twisted.web.util import redirectTo
3436 import collections
3537 import logging
3638 import urllib
37 import ujson
39 import simplejson
3840
3941 logger = logging.getLogger(__name__)
4042
4143 metrics = synapse.metrics.get_metrics_for(__name__)
4244
43 incoming_requests_counter = metrics.register_counter(
44 "requests",
45 # total number of responses served, split by method/servlet/tag
46 response_count = metrics.register_counter(
47 "response_count",
4548 labels=["method", "servlet", "tag"],
46 )
49 alternative_names=(
50 # the following are all deprecated aliases for the same metric
51 metrics.name_prefix + x for x in (
52 "_requests",
53 "_response_time:count",
54 "_response_ru_utime:count",
55 "_response_ru_stime:count",
56 "_response_db_txn_count:count",
57 "_response_db_txn_duration:count",
58 )
59 )
60 )
61
62 requests_counter = metrics.register_counter(
63 "requests_received",
64 labels=["method", "servlet", ],
65 )
66
4767 outgoing_responses_counter = metrics.register_counter(
4868 "responses",
4969 labels=["method", "code"],
5070 )
5171
52 response_timer = metrics.register_distribution(
53 "response_time",
54 labels=["method", "servlet", "tag"]
55 )
56
57 response_ru_utime = metrics.register_distribution(
58 "response_ru_utime", labels=["method", "servlet", "tag"]
59 )
60
61 response_ru_stime = metrics.register_distribution(
62 "response_ru_stime", labels=["method", "servlet", "tag"]
63 )
64
65 response_db_txn_count = metrics.register_distribution(
66 "response_db_txn_count", labels=["method", "servlet", "tag"]
67 )
68
69 response_db_txn_duration = metrics.register_distribution(
70 "response_db_txn_duration", labels=["method", "servlet", "tag"]
71 )
72
72 response_timer = metrics.register_counter(
73 "response_time_seconds",
74 labels=["method", "servlet", "tag"],
75 alternative_names=(
76 metrics.name_prefix + "_response_time:total",
77 ),
78 )
79
80 response_ru_utime = metrics.register_counter(
81 "response_ru_utime_seconds", labels=["method", "servlet", "tag"],
82 alternative_names=(
83 metrics.name_prefix + "_response_ru_utime:total",
84 ),
85 )
86
87 response_ru_stime = metrics.register_counter(
88 "response_ru_stime_seconds", labels=["method", "servlet", "tag"],
89 alternative_names=(
90 metrics.name_prefix + "_response_ru_stime:total",
91 ),
92 )
93
94 response_db_txn_count = metrics.register_counter(
95 "response_db_txn_count", labels=["method", "servlet", "tag"],
96 alternative_names=(
97 metrics.name_prefix + "_response_db_txn_count:total",
98 ),
99 )
100
101 # seconds spent waiting for db txns, excluding scheduling time, when processing
102 # this request
103 response_db_txn_duration = metrics.register_counter(
104 "response_db_txn_duration_seconds", labels=["method", "servlet", "tag"],
105 alternative_names=(
106 metrics.name_prefix + "_response_db_txn_duration:total",
107 ),
108 )
109
110 # seconds spent waiting for a db connection, when processing this request
111 response_db_sched_duration = metrics.register_counter(
112 "response_db_sched_duration_seconds", labels=["method", "servlet", "tag"]
113 )
73114
74115 _next_request_id = 0
75116
105146 with LoggingContext(request_id) as request_context:
106147 with Measure(self.clock, "wrapped_request_handler"):
107148 request_metrics = RequestMetrics()
108 request_metrics.start(self.clock, name=self.__class__.__name__)
149 # we start the request metrics timer here with an initial stab
150 # at the servlet name. For most requests that name will be
151 # JsonResource (or a subclass), and JsonResource._async_render
152 # will update it once it picks a servlet.
153 servlet_name = self.__class__.__name__
154 request_metrics.start(self.clock, name=servlet_name)
109155
110156 request_context.request = request_id
111157 with request.processing():
114160 if include_metrics:
115161 yield request_handler(self, request, request_metrics)
116162 else:
163 requests_counter.inc(request.method, servlet_name)
117164 yield request_handler(self, request)
118165 except CodeMessageException as e:
119166 code = e.code
129176 pretty_print=_request_user_agent_is_curl(request),
130177 version_string=self.version_string,
131178 )
132 except:
133 logger.exception(
134 "Failed handle request %s.%s on %r: %r",
179 except Exception:
180 # failure.Failure() fishes the original Failure out
181 # of our stack, and thus gives us a sensible stack
182 # trace.
183 f = failure.Failure()
184 logger.error(
185 "Failed handle request %s.%s on %r: %r: %s",
135186 request_handler.__module__,
136187 request_handler.__name__,
137188 self,
138 request
189 request,
190 f.getTraceback().rstrip(),
139191 )
140192 respond_with_json(
141193 request,
184236 """ This implements the HttpServer interface and provides JSON support for
185237 Resources.
186238
187 Register callbacks via register_path()
239 Register callbacks via register_paths()
188240
189241 Callbacks can return a tuple of status code and a dict in which case
190242 the dict will automatically be sent to the client as a JSON object.
231283 This checks if anyone has registered a callback for that method and
232284 path.
233285 """
286 callback, group_dict = self._get_handler_for_request(request)
287
288 servlet_instance = getattr(callback, "__self__", None)
289 if servlet_instance is not None:
290 servlet_classname = servlet_instance.__class__.__name__
291 else:
292 servlet_classname = "%r" % callback
293
294 request_metrics.name = servlet_classname
295 requests_counter.inc(request.method, servlet_classname)
296
297 # Now trigger the callback. If it returns a response, we send it
298 # here. If it throws an exception, that is handled by the wrapper
299 # installed by @request_handler.
300
301 kwargs = intern_dict({
302 name: urllib.unquote(value).decode("UTF-8") if value else value
303 for name, value in group_dict.items()
304 })
305
306 callback_return = yield callback(request, **kwargs)
307 if callback_return is not None:
308 code, response = callback_return
309 self._send_response(request, code, response)
310
311 def _get_handler_for_request(self, request):
312 """Finds a callback method to handle the given request
313
314 Args:
315 request (twisted.web.http.Request):
316
317 Returns:
318 Tuple[Callable, dict[str, str]]: callback method, and the dict
319 mapping keys to path components as specified in the handler's
320 path match regexp.
321
322 The callback will normally be a method registered via
323 register_paths, so will return (possibly via Deferred) either
324 None, or a tuple of (http code, response body).
325 """
234326 if request.method == "OPTIONS":
235 self._send_response(request, 200, {})
236 return
327 return _options_handler, {}
237328
238329 # Loop through all the registered callbacks to check if the method
239330 # and path regex match
240331 for path_entry in self.path_regexs.get(request.method, []):
241332 m = path_entry.pattern.match(request.path)
242 if not m:
243 continue
244
245 # We found a match! Trigger callback and then return the
246 # returned response. We pass both the request and any
247 # matched groups from the regex to the callback.
248
249 callback = path_entry.callback
250
251 kwargs = intern_dict({
252 name: urllib.unquote(value).decode("UTF-8") if value else value
253 for name, value in m.groupdict().items()
254 })
255
256 callback_return = yield callback(request, **kwargs)
257 if callback_return is not None:
258 code, response = callback_return
259 self._send_response(request, code, response)
260
261 servlet_instance = getattr(callback, "__self__", None)
262 if servlet_instance is not None:
263 servlet_classname = servlet_instance.__class__.__name__
264 else:
265 servlet_classname = "%r" % callback
266
267 request_metrics.name = servlet_classname
268
269 return
333 if m:
334 # We found a match!
335 return path_entry.callback, m.groupdict()
270336
271337 # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
272 raise UnrecognizedRequestError()
338 return _unrecognised_request_handler, {}
273339
274340 def _send_response(self, request, code, response_json_object,
275341 response_code_message=None):
276 # could alternatively use request.notifyFinish() and flip a flag when
277 # the Deferred fires, but since the flag is RIGHT THERE it seems like
278 # a waste.
279 if request._disconnected:
280 logger.warn(
281 "Not sending response to request %s, already disconnected.",
282 request)
283 return
284
285342 outgoing_responses_counter.inc(request.method, str(code))
286343
287344 # TODO: Only enable CORS for the requests that need it.
295352 )
296353
297354
355 def _options_handler(request):
356 """Request handler for OPTIONS requests
357
358 This is a request handler suitable for return from
359 _get_handler_for_request. It returns a 200 and an empty body.
360
361 Args:
362 request (twisted.web.http.Request):
363
364 Returns:
365 Tuple[int, dict]: http code, response body.
366 """
367 return 200, {}
368
369
370 def _unrecognised_request_handler(request):
371 """Request handler for unrecognised requests
372
373 This is a request handler suitable for return from
374 _get_handler_for_request. It actually just raises an
375 UnrecognizedRequestError.
376
377 Args:
378 request (twisted.web.http.Request):
379 """
380 raise UnrecognizedRequestError()
381
382
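The refactor above pulls OPTIONS handling and the 400 fallback out of the dispatch loop, so _get_handler_for_request always returns a (callback, group_dict) pair and the caller decides what to do with it. The following is a stripped-down, self-contained sketch of that shape only; the dispatcher class, the example path regex and the get_event callback are invented for illustration and are not Synapse's JsonResource.

import re


def options_handler(request):
    # stand-in for _options_handler: 200 with an empty JSON body
    return 200, {}


def unrecognised_request_handler(request):
    # stand-in for _unrecognised_request_handler
    raise ValueError("Unrecognized request")


class TinyDispatcher(object):
    def __init__(self):
        self.path_regexs = {}  # method -> list of (compiled pattern, callback)

    def register_path(self, method, pattern, callback):
        self.path_regexs.setdefault(method, []).append(
            (re.compile(pattern), callback)
        )

    def get_handler_for_request(self, method, path):
        # Mirrors the refactor: always return (callback, group_dict) rather
        # than special-casing OPTIONS and unknown paths inline.
        if method == "OPTIONS":
            return options_handler, {}
        for pattern, callback in self.path_regexs.get(method, []):
            m = pattern.match(path)
            if m:
                return callback, m.groupdict()
        return unrecognised_request_handler, {}


def get_event(request, room_id, event_id):
    return 200, {"room_id": room_id, "event_id": event_id}


dispatcher = TinyDispatcher()
dispatcher.register_path(
    "GET", r"^/rooms/(?P<room_id>[^/]+)/event/(?P<event_id>[^/]+)$", get_event,
)
callback, groups = dispatcher.get_handler_for_request(
    "GET", "/rooms/!abc:example.com/event/$ev1",
)
print(callback(None, **groups))  # (200, {'room_id': '!abc:example.com', 'event_id': '$ev1'})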
298383 class RequestMetrics(object):
299384 def start(self, clock, name):
300385 self.start = clock.time_msec()
315400 )
316401 return
317402
318 incoming_requests_counter.inc(request.method, self.name, tag)
403 response_count.inc(request.method, self.name, tag)
319404
320405 response_timer.inc_by(
321406 clock.time_msec() - self.start, request.method,
334419 context.db_txn_count, request.method, self.name, tag
335420 )
336421 response_db_txn_duration.inc_by(
337 context.db_txn_duration, request.method, self.name, tag
422 context.db_txn_duration_ms / 1000., request.method, self.name, tag
423 )
424 response_db_sched_duration.inc_by(
425 context.db_sched_duration_ms / 1000., request.method, self.name, tag
338426 )
339427
340428
357445 def respond_with_json(request, code, json_object, send_cors=False,
358446 response_code_message=None, pretty_print=False,
359447 version_string="", canonical_json=True):
448 # could alternatively use request.notifyFinish() and flip a flag when
449 # the Deferred fires, but since the flag is RIGHT THERE it seems like
450 # a waste.
451 if request._disconnected:
452 logger.warn(
453 "Not sending response to request %s, already disconnected.",
454 request)
455 return
456
360457 if pretty_print:
361458 json_bytes = encode_pretty_printed_json(json_object) + "\n"
362459 else:
363460 if canonical_json or synapse.events.USE_FROZEN_DICTS:
364461 json_bytes = encode_canonical_json(json_object)
365462 else:
366 # ujson doesn't like frozen_dicts.
367 json_bytes = ujson.dumps(json_object, ensure_ascii=False)
463 json_bytes = simplejson.dumps(json_object)
368464
369465 return respond_with_json_bytes(
370466 request, code, json_bytes,
4747 if name in args:
4848 try:
4949 return int(args[name][0])
50 except:
50 except Exception:
5151 message = "Query parameter %r must be an integer" % (name,)
5252 raise SynapseError(400, message)
5353 else:
8787 "true": True,
8888 "false": False,
8989 }[args[name][0]]
90 except:
90 except Exception:
9191 message = (
9292 "Boolean query parameter %r must be one of"
9393 " ['true', 'false']"
147147 return default
148148
149149
150 def parse_json_value_from_request(request):
150 def parse_json_value_from_request(request, allow_empty_body=False):
151151 """Parse a JSON value from the body of a twisted HTTP request.
152152
153153 Args:
154154 request: the twisted HTTP request.
155 allow_empty_body (bool): if True, an empty body will be accepted and
156 turned into None
155157
156158 Returns:
157159 The JSON value.
161163 """
162164 try:
163165 content_bytes = request.content.read()
164 except:
166 except Exception:
165167 raise SynapseError(400, "Error reading JSON content.")
168
169 if not content_bytes and allow_empty_body:
170 return None
166171
167172 try:
168173 content = simplejson.loads(content_bytes)
169 except simplejson.JSONDecodeError:
174 except Exception as e:
175 logger.warn("Unable to parse JSON: %s", e)
170176 raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
171177
172178 return content
173179
174180
175 def parse_json_object_from_request(request):
181 def parse_json_object_from_request(request, allow_empty_body=False):
176182 """Parse a JSON object from the body of a twisted HTTP request.
177183
178184 Args:
179185 request: the twisted HTTP request.
186 allow_empty_body (bool): if True, an empty body will be accepted and
187 turned into an empty dict.
180188
181189 Raises:
182190 SynapseError if the request body couldn't be decoded as JSON or
183191 if it wasn't a JSON object.
184192 """
185 content = parse_json_value_from_request(request)
193 content = parse_json_value_from_request(
194 request, allow_empty_body=allow_empty_body,
195 )
196
197 if allow_empty_body and content is None:
198 return {}
186199
187200 if type(content) != dict:
188201 message = "Content must be a JSON object."
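A hedged usage sketch for the new allow_empty_body flag: a servlet whose request body is optional can now accept an empty body and get back an empty dict instead of a 400. The servlet class, its endpoint and the "reason" field below are invented for illustration; only the parse_json_object_from_request call reflects the code above.

from twisted.internet import defer

from synapse.http.servlet import RestServlet, parse_json_object_from_request


class ExampleOptionalBodyServlet(RestServlet):
    """Hypothetical servlet whose POST body may be empty or a JSON object."""

    @defer.inlineCallbacks
    def on_POST(self, request):
        # With allow_empty_body=True an empty body becomes {}; a malformed
        # body still raises SynapseError(400, "Content not JSON.").
        content = parse_json_object_from_request(request, allow_empty_body=True)
        reason = content.get("reason", "<no reason supplied>")
        yield defer.succeed(None)  # placeholder for real async work
        defer.returnValue((200, {"reason": reason}))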
6565 context = LoggingContext.current_context()
6666 ru_utime, ru_stime = context.get_resource_usage()
6767 db_txn_count = context.db_txn_count
68 db_txn_duration = context.db_txn_duration
69 except:
68 db_txn_duration_ms = context.db_txn_duration_ms
69 db_sched_duration_ms = context.db_sched_duration_ms
70 except Exception:
7071 ru_utime, ru_stime = (0, 0)
71 db_txn_count, db_txn_duration = (0, 0)
72 db_txn_count, db_txn_duration_ms = (0, 0)
7273
7374 self.site.access_logger.info(
7475 "%s - %s - {%s}"
75 " Processed request: %dms (%dms, %dms) (%dms/%d)"
76 " Processed request: %dms (%dms, %dms) (%dms/%dms/%d)"
7677 " %sB %s \"%s %s %s\" \"%s\"",
7778 self.getClientIP(),
7879 self.site.site_tag,
8081 int(time.time() * 1000) - self.start_time,
8182 int(ru_utime * 1000),
8283 int(ru_stime * 1000),
83 int(db_txn_duration * 1000),
84 db_sched_duration_ms,
85 db_txn_duration_ms,
8486 int(db_txn_count),
8587 self.sentLength,
8688 self.code,
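With the extra %dms field, the parenthesised group in the access log now reads (sched/txn/count) rather than (txn/count). Purely illustrative rendering of the new shape, with made-up values and a simplified tail:

# All values below are invented; this only shows the ordering of the new
# "(%dms/%dms/%d)" group: DB scheduling time, DB transaction time, txn count.
line = (
    "%s - %s - {%s} Processed request: %dms (%dms, %dms) (%dms/%dms/%d)"
    % ("192.0.2.10", "client_reader", "@alice:example.com",
       45, 3, 1, 5, 12, 2)
)
print(line)
# 192.0.2.10 - client_reader - {@alice:example.com} Processed request: 45ms (3ms, 1ms) (5ms/12ms/2)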
5656 return metric
5757
5858 def register_counter(self, *args, **kwargs):
59 """
60 Returns:
61 CounterMetric
62 """
5963 return self._register(CounterMetric, *args, **kwargs)
6064
6165 def register_callback(self, *args, **kwargs):
66 """
67 Returns:
68 CallbackMetric
69 """
6270 return self._register(CallbackMetric, *args, **kwargs)
6371
6472 def register_distribution(self, *args, **kwargs):
73 """
74 Returns:
75 DistributionMetric
76 """
6577 return self._register(DistributionMetric, *args, **kwargs)
6678
6779 def register_cache(self, *args, **kwargs):
80 """
81 Returns:
82 CacheMetric
83 """
6884 return self._register(CacheMetric, *args, **kwargs)
6985
7086
145161 num_pending += 1
146162
147163 num_pending += len(reactor.threadCallQueue)
148
149164 start = time.time() * 1000
150165 ret = func(*args, **kwargs)
151166 end = time.time() * 1000
167
168 # record the amount of wallclock time spent running pending calls.
169 # This is a proxy for the actual amount of time between reactor polls,
170 # since about 25% of time is actually spent running things triggered by
171 # I/O events, but that is harder to capture without rewriting half the
172 # reactor.
152173 tick_time.inc_by(end - start)
153174 pending_calls_metric.inc_by(num_pending)
154175
1414
1515
1616 from itertools import chain
17
18
19 # TODO(paul): I can't believe Python doesn't have one of these
20 def map_concat(func, items):
21 # flatten a list-of-lists
22 return list(chain.from_iterable(map(func, items)))
17 import logging
18
19 logger = logging.getLogger(__name__)
20
21
22 def flatten(items):
23 """Flatten a list of lists
24
25 Args:
26 items: iterable[iterable[X]]
27
28 Returns:
29 list[X]: flattened list
30 """
31 return list(chain.from_iterable(items))
2332
2433
2534 class BaseMetric(object):
26
27 def __init__(self, name, labels=[]):
28 self.name = name
35 """Base class for metrics which report a single value per label set
36 """
37
38 def __init__(self, name, labels=[], alternative_names=[]):
39 """
40 Args:
41 name (str): principal name for this metric
42 labels (list(str)): names of the labels which will be reported
43 for this metric
44 alternative_names (iterable(str)): list of alternative names for
45 this metric. This can be useful to provide a migration path
46 when renaming metrics.
47 """
48 self._names = [name] + list(alternative_names)
2949 self.labels = labels # OK not to clone as we never write it
3050
3151 def dimension(self):
3555 return not len(self.labels)
3656
3757 def _render_labelvalue(self, value):
38 # TODO: some kind of value escape
58 # TODO: escape backslashes, quotes and newlines
3959 return '"%s"' % (value)
4060
4161 def _render_key(self, values):
4666 for k, v in zip(self.labels, values)])
4767 )
4868
69 def _render_for_labels(self, label_values, value):
70 """Render this metric for a single set of labels
71
72 Args:
73 label_values (list[str]): values for each of the labels
74 value: value of the metric with these labels
75
76 Returns:
77 iterable[str]: rendered metric
78 """
79 rendered_labels = self._render_key(label_values)
80 return (
81 "%s%s %.12g" % (name, rendered_labels, value)
82 for name in self._names
83 )
84
85 def render(self):
86 """Render this metric
87
88 Each metric is rendered as:
89
90 name{label1="val1",label2="val2"} value
91
92 https://prometheus.io/docs/instrumenting/exposition_formats/#text-format-details
93
94 Returns:
95 iterable[str]: rendered metrics
96 """
97 raise NotImplementedError()
98
4999
50100 class CounterMetric(BaseMetric):
51101 """The simplest kind of metric; one that stores a monotonically-increasing
52 integer that counts events."""
102 value that counts events or running totals.
103
104 Example use cases for Counters:
105 - Number of requests processed
106 - Number of items that were inserted into a queue
107 - Total amount of data that a system has processed
108 Counters can only go up (and be reset when the process restarts).
109 """
53110
54111 def __init__(self, *args, **kwargs):
55112 super(CounterMetric, self).__init__(*args, **kwargs)
56113
114 # dict[list[str]]: value for each set of label values. the keys are the
115 # label values, in the same order as the labels in self.labels.
116 #
117 # (if the metric is a scalar, the (single) key is the empty list).
57118 self.counts = {}
58119
59120 # Scalar metrics are never empty
60121 if self.is_scalar():
61 self.counts[()] = 0
122 self.counts[()] = 0.
62123
63124 def inc_by(self, incr, *values):
64125 if len(values) != self.dimension():
76137 def inc(self, *values):
77138 self.inc_by(1, *values)
78139
79 def render_item(self, k):
80 return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])]
81
82 def render(self):
83 return map_concat(self.render_item, sorted(self.counts.keys()))
140 def render(self):
141 return flatten(
142 self._render_for_labels(k, self.counts[k])
143 for k in sorted(self.counts.keys())
144 )
84145
85146
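To make the alternative_names / _render_for_labels changes concrete, here is a self-contained toy that renders a labelled counter in the Prometheus text format, emitting each sample once per name so a deprecated metric name can keep being exported during a rename. It mirrors the structure above but is not Synapse's CounterMetric, and the metric names in the example are illustrative.

from itertools import chain


class TinyCounter(object):
    """Toy labelled counter; not Synapse's CounterMetric."""

    def __init__(self, name, labels, alternative_names=()):
        self._names = [name] + list(alternative_names)
        self.labels = labels
        self.counts = {}

    def inc(self, *values):
        self.counts[values] = self.counts.get(values, 0.) + 1

    def _render_key(self, values):
        return "{%s}" % (
            ",".join('%s="%s"' % (k, v) for k, v in zip(self.labels, values))
        )

    def render(self):
        # one output line per name per label set, so the deprecated name keeps
        # being exported alongside the new one during a rename
        return list(chain.from_iterable(
            ("%s%s %.12g" % (name, self._render_key(k), v) for name in self._names)
            for k, v in sorted(self.counts.items())
        ))


response_count = TinyCounter(
    "synapse_http_server_response_count",
    labels=["method", "servlet"],
    alternative_names=["synapse_http_server_requests"],  # deprecated name
)
response_count.inc("GET", "RoomStateRestServlet")
for line in response_count.render():
    print(line)
# synapse_http_server_response_count{method="GET",servlet="RoomStateRestServlet"} 1
# synapse_http_server_requests{method="GET",servlet="RoomStateRestServlet"} 1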
86147 class CallbackMetric(BaseMetric):
94155 self.callback = callback
95156
96157 def render(self):
97 value = self.callback()
158 try:
159 value = self.callback()
160 except Exception:
161 logger.exception("Failed to render %s", self.name)
162 return ["# FAILED to render " + self.name]
98163
99164 if self.is_scalar():
100 return ["%s %.12g" % (self.name, value)]
101
102 return ["%s%s %.12g" % (self.name, self._render_key(k), value[k])
103 for k in sorted(value.keys())]
165 return list(self._render_for_labels([], value))
166
167 return flatten(
168 self._render_for_labels(k, value[k])
169 for k in sorted(value.keys())
170 )
104171
105172
106173 class DistributionMetric(object):
125192
126193
127194 class CacheMetric(object):
128 __slots__ = ("name", "cache_name", "hits", "misses", "size_callback")
195 __slots__ = (
196 "name", "cache_name", "hits", "misses", "evicted_size", "size_callback",
197 )
129198
130199 def __init__(self, name, size_callback, cache_name):
131200 self.name = name
133202
134203 self.hits = 0
135204 self.misses = 0
205 self.evicted_size = 0
136206
137207 self.size_callback = size_callback
138208
141211
142212 def inc_misses(self):
143213 self.misses += 1
214
215 def inc_evictions(self, size=1):
216 self.evicted_size += size
144217
145218 def render(self):
146219 size = self.size_callback()
151224 """%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
152225 """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
153226 """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
227 """%s:evicted_size{name="%s"} %d""" % (
228 self.name, self.cache_name, self.evicted_size
229 ),
154230 ]
155231
156232
0 # -*- coding: utf-8 -*-
1 # Copyright 2017 New Vector Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 from twisted.internet import defer
15
16 from synapse.types import UserID
17
18
19 class ModuleApi(object):
20 """A proxy object that gets passed to password auth providers so they
21 can register new users etc if necessary.
22 """
23 def __init__(self, hs, auth_handler):
24 self.hs = hs
25
26 self._store = hs.get_datastore()
27 self._auth = hs.get_auth()
28 self._auth_handler = auth_handler
29
30 def get_user_by_req(self, req, allow_guest=False):
31 """Check the access_token provided for a request
32
33 Args:
34 req (twisted.web.server.Request): Incoming HTTP request
35 allow_guest (bool): True if guest users should be allowed. If this
36 is False, and the access token is for a guest user, an
37 AuthError will be thrown
38 Returns:
39 twisted.internet.defer.Deferred[synapse.types.Requester]:
40 the requester for this request
41 Raises:
42 synapse.api.errors.AuthError: if no user by that token exists,
43 or the token is invalid.
44 """
45 return self._auth.get_user_by_req(req, allow_guest)
46
47 def get_qualified_user_id(self, username):
48 """Qualify a user id, if necessary
49
50 Takes a user id provided by the user and adds the @ and :domain to
51 qualify it, if necessary
52
53 Args:
54 username (str): provided user id
55
56 Returns:
57 str: qualified @user:id
58 """
59 if username.startswith('@'):
60 return username
61 return UserID(username, self.hs.hostname).to_string()
62
63 def check_user_exists(self, user_id):
64 """Check if user exists.
65
66 Args:
67 user_id (str): Complete @user:id
68
69 Returns:
70 Deferred[str|None]: Canonical (case-corrected) user_id, or None
71 if the user is not registered.
72 """
73 return self._auth_handler.check_user_exists(user_id)
74
75 def register(self, localpart):
76 """Registers a new user with given localpart
77
78 Returns:
79 Deferred: a 2-tuple of (user_id, access_token)
80 """
81 reg = self.hs.get_handlers().registration_handler
82 return reg.register(localpart=localpart)
83
84 @defer.inlineCallbacks
85 def invalidate_access_token(self, access_token):
86 """Invalidate an access token for a user
87
88 Args:
89 access_token(str): access token
90
91 Returns:
92 twisted.internet.defer.Deferred - resolves once the access token
93 has been removed.
94
95 Raises:
96 synapse.api.errors.AuthError: the access token is invalid
97 """
98 # see if the access token corresponds to a device
99 user_info = yield self._auth.get_user_by_access_token(access_token)
100 device_id = user_info.get("device_id")
101 user_id = user_info["user"].to_string()
102 if device_id:
103 # delete the device, which will also delete its access tokens
104 yield self.hs.get_device_handler().delete_device(user_id, device_id)
105 else:
106 # no associated device. Just delete the access token.
107 yield self._auth_handler.delete_access_token(access_token)
108
109 def run_db_interaction(self, desc, func, *args, **kwargs):
110 """Run a function with a database connection
111
112 Args:
113 desc (str): description for the transaction, for metrics etc
114 func (func): function to be run. Passed a database cursor object
115 as well as *args and **kwargs
116 *args: positional args to be passed to func
117 **kwargs: named args to be passed to func
118
119 Returns:
120 Deferred[object]: result of func
121 """
122 return self._store.runInteraction(desc, func, *args, **kwargs)
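The ModuleApi above is the proxy object handed to password auth providers. As a hedged illustration of how it might be used, here is a hypothetical provider: the class, its config key and the check_password hook shape are assumptions for the sake of the example, not a documented contract; only the account_handler methods come from the docstrings above.

from twisted.internet import defer


class ExamplePasswordProvider(object):
    """Hypothetical password auth provider using the ModuleApi proxy."""

    def __init__(self, config, account_handler):
        # account_handler is the ModuleApi instance described above
        self.account_handler = account_handler
        self.shared_secret = config.get("shared_secret", "")

    @defer.inlineCallbacks
    def check_password(self, user_id, password):
        # Qualify whatever the user typed into a full @user:domain ID.
        user_id = self.account_handler.get_qualified_user_id(user_id)

        if password != self.shared_secret:  # stand-in for a real backend check
            defer.returnValue(False)

        # Register the user on first login; otherwise the account exists.
        canonical_id = yield self.account_handler.check_user_exists(user_id)
        if canonical_id is None:
            localpart = user_id.split(":", 1)[0][1:]
            yield self.account_handler.register(localpart)
        defer.returnValue(True)

    @staticmethod
    def parse_config(config):
        return config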
254254 )
255255
256256 if self.federation_sender:
257 preserve_fn(self.federation_sender.notify_new_events)(
258 room_stream_id
259 )
257 self.federation_sender.notify_new_events(room_stream_id)
260258
261259 if event.type == EventTypes.Member and event.membership == Membership.JOIN:
262260 self._user_joined_room(event.state_key, event.room_id)
288286 for user_stream in user_streams:
289287 try:
290288 user_stream.notify(stream_key, new_token, time_now_ms)
291 except:
289 except Exception:
292290 logger.exception("Failed to notify listener")
293291
294292 self.notify_replication()
296294 def on_new_replication_data(self):
297295 """Used to inform replication listeners that something has happend
298296 without waking up any of the normal user event streams"""
299 with PreserveLoggingContext():
300 self.notify_replication()
297 self.notify_replication()
301298
302299 @defer.inlineCallbacks
303300 def wait_for_events(self, user_id, timeout, callback, room_ids=None,
515512 self.replication_deferred = ObservableDeferred(defer.Deferred())
516513 deferred.callback(None)
517514
518 for cb in self.replication_callbacks:
519 preserve_fn(cb)()
515 # the callbacks may well outlast the current request, so we run
516 # them in the sentinel logcontext.
517 #
518 # (ideally it would be up to the callbacks to know if they were
519 # starting off background processes and drop the logcontext
520 # accordingly, but that requires more changes)
521 for cb in self.replication_callbacks:
522 cb()
520523
521524 @defer.inlineCallbacks
522525 def wait_for_replication(self, callback, timeout):
3939 @defer.inlineCallbacks
4040 def handle_push_actions_for_event(self, event, context):
4141 with Measure(self.clock, "action_for_event_by_user"):
42 actions_by_user = yield self.bulk_evaluator.action_for_event_by_user(
42 yield self.bulk_evaluator.action_for_event_by_user(
4343 event, context
4444 )
45
46 context.push_actions = [
47 (uid, actions) for uid, actions in actions_by_user.iteritems()
48 ]
136136
137137 @defer.inlineCallbacks
138138 def action_for_event_by_user(self, event, context):
139 """Given an event and context, evaluate the push rules and return
140 the results
139 """Given an event and context, evaluate the push rules and insert the
140 results into the event_push_actions_staging table.
141141
142142 Returns:
143 dict of user_id -> action
143 Deferred
144144 """
145145 rules_by_user = yield self._get_rules_for_event(event, context)
146146 actions_by_user = {}
189189 if matches:
190190 actions = [x for x in rule['actions'] if x != 'dont_notify']
191191 if actions and 'notify' in actions:
192 # Push rules say we should notify the user of this event
192193 actions_by_user[uid] = actions
193194 break
194 defer.returnValue(actions_by_user)
195
196 # Mark in the DB staging area the push actions for users who should be
197 # notified for this event. (This will then get handled when we persist
198 # the event)
199 yield self.store.add_push_actions_to_staging(
200 event.event_id, actions_by_user,
201 )
195202
196203
197204 def _condition_checker(evaluator, conditions, uid, display_name, cache):
120120 starting_max_ordering = self.max_stream_ordering
121121 try:
122122 yield self._unsafe_process()
123 except:
123 except Exception:
124124 logger.exception("Exception processing notifs")
125125 if self.max_stream_ordering == starting_max_ordering:
126126 break
00 # -*- coding: utf-8 -*-
11 # Copyright 2015, 2016 OpenMarket Ltd
2 # Copyright 2017 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1112 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1213 # See the License for the specific language governing permissions and
1314 # limitations under the License.
14
15 from synapse.push import PusherConfigException
15 import logging
1616
1717 from twisted.internet import defer, reactor
1818 from twisted.internet.error import AlreadyCalled, AlreadyCancelled
1919
20 import logging
2120 import push_rule_evaluator
2221 import push_tools
23
22 import synapse
23 from synapse.push import PusherConfigException
2424 from synapse.util.logcontext import LoggingContext
2525 from synapse.util.metrics import Measure
2626
2727 logger = logging.getLogger(__name__)
28
29 metrics = synapse.metrics.get_metrics_for(__name__)
30
31 http_push_processed_counter = metrics.register_counter(
32 "http_pushes_processed",
33 )
34
35 http_push_failed_counter = metrics.register_counter(
36 "http_pushes_failed",
37 )
2838
2939
3040 class HttpPusher(object):
130140 starting_max_ordering = self.max_stream_ordering
131141 try:
132142 yield self._unsafe_process()
133 except:
143 except Exception:
134144 logger.exception("Exception processing notifs")
135145 if self.max_stream_ordering == starting_max_ordering:
136146 break
150160 self.user_id, self.last_stream_ordering, self.max_stream_ordering
151161 )
152162
163 logger.info(
164 "Processing %i unprocessed push actions for %s starting at "
165 "stream_ordering %s",
166 len(unprocessed), self.name, self.last_stream_ordering,
167 )
168
153169 for push_action in unprocessed:
154170 processed = yield self._process_one(push_action)
155171 if processed:
172 http_push_processed_counter.inc()
156173 self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
157174 self.last_stream_ordering = push_action['stream_ordering']
158175 yield self.store.update_pusher_last_stream_ordering_and_success(
167184 self.failing_since
168185 )
169186 else:
187 http_push_failed_counter.inc()
170188 if not self.failing_since:
171189 self.failing_since = self.clock.time_msec()
172190 yield self.store.update_pusher_failing_since(
294312 if event.type == 'm.room.member':
295313 d['notification']['membership'] = event.content['membership']
296314 d['notification']['user_is_target'] = event.state_key == self.user_id
297 if not self.hs.config.push_redact_content and 'content' in event:
315 if self.hs.config.push_include_content and 'content' in event:
298316 d['notification']['content'] = event.content
299317
300318 # We no longer send aliases separately, instead, we send the human
313331 defer.returnValue([])
314332 try:
315333 resp = yield self.http_client.post_json_get_json(self.url, notification_dict)
316 except:
317 logger.warn("Failed to push %s ", self.url)
334 except Exception:
335 logger.warn(
336 "Failed to push event %s to %s",
337 event.event_id, self.name, exc_info=True,
338 )
318339 defer.returnValue(False)
319340 rejected = []
320341 if 'rejected' in resp:
323344
324345 @defer.inlineCallbacks
325346 def _send_badge(self, badge):
326 logger.info("Sending updated badge count %d to %r", badge, self.user_id)
347 logger.info("Sending updated badge count %d to %s", badge, self.name)
327348 d = {
328349 'notification': {
329350 'id': '',
344365 }
345366 try:
346367 resp = yield self.http_client.post_json_get_json(self.url, d)
347 except:
348 logger.exception("Failed to push %s ", self.url)
368 except Exception:
369 logger.warn(
370 "Failed to send badge count to %s",
371 self.name, exc_info=True,
372 )
349373 defer.returnValue(False)
350374 rejected = []
351375 if 'rejected' in resp:
2626 try:
2727 from synapse.push.emailpusher import EmailPusher
2828 from synapse.push.mailer import Mailer, load_jinja2_templates
29 except:
29 except Exception:
3030 pass
3131
3232
1616 from twisted.internet import defer
1717
1818 from .pusher import PusherFactory
19 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
19 from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
2020 from synapse.util.async import run_on_reactor
2121
2222 import logging
102102 yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
103103
104104 @defer.inlineCallbacks
105 def remove_pushers_by_user(self, user_id, except_access_token_id=None):
106 all = yield self.store.get_all_pushers()
107 logger.info(
108 "Removing all pushers for user %s except access tokens id %r",
109 user_id, except_access_token_id
110 )
111 for p in all:
112 if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
105 def remove_pushers_by_access_token(self, user_id, access_tokens):
106 """Remove the pushers for a given user corresponding to a set of
107 access_tokens.
108
109 Args:
110 user_id (str): user to remove pushers for
111 access_tokens (Iterable[int]): access token *ids* to remove pushers
112 for
113 """
114 tokens = set(access_tokens)
115 for p in (yield self.store.get_pushers_by_user_id(user_id)):
116 if p['access_token'] in tokens:
113117 logger.info(
114118 "Removing pusher for app id %s, pushkey %s, user %s",
115119 p['app_id'], p['pushkey'], p['user_name']
116120 )
117 yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
121 yield self.remove_pusher(
122 p['app_id'], p['pushkey'], p['user_name'],
123 )
118124
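A brief, hedged usage sketch of the renamed API: callers now pass the concrete set of access-token IDs whose pushers should be removed (for example the tokens just invalidated by a password change), rather than an exception list. The surrounding function and variable names are assumptions for illustration.

from twisted.internet import defer


@defer.inlineCallbacks
def _expire_logins(pusher_pool, user_id, deleted_access_token_ids):
    # deleted_access_token_ids: iterable of access-token *ids* that have just
    # been invalidated; the pushers tied to them should go away too.
    yield pusher_pool.remove_pushers_by_access_token(
        user_id, deleted_access_token_ids,
    )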
119125 @defer.inlineCallbacks
120126 def on_new_notifications(self, min_stream_id, max_stream_id):
135141 )
136142 )
137143
138 yield preserve_context_over_deferred(defer.gatherResults(deferreds))
139 except:
144 yield make_deferred_yieldable(defer.gatherResults(deferreds))
145 except Exception:
140146 logger.exception("Exception in pusher on_new_notifications")
141147
142148 @defer.inlineCallbacks
160166 preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
161167 )
162168
163 yield preserve_context_over_deferred(defer.gatherResults(deferreds))
164 except:
169 yield make_deferred_yieldable(defer.gatherResults(deferreds))
170 except Exception:
165171 logger.exception("Exception in pusher on_new_receipts")
166172
167173 @defer.inlineCallbacks
187193 for pusherdict in pushers:
188194 try:
189195 p = self.pusher_factory.create_pusher(pusherdict)
190 except:
196 except Exception:
191197 logger.exception("Couldn't start a pusher: caught Exception")
192198 continue
193199 if p:
2323 "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
2424 "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
2525 "signedjson>=1.0.0": ["signedjson>=1.0.0"],
26 "pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"],
26 "pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"],
2727 "service_identity>=1.0.0": ["service_identity>=1.0.0"],
2828 "Twisted>=16.0.0": ["twisted>=16.0.0"],
2929 "pyopenssl>=0.14": ["OpenSSL>=0.14"],
3030 "pyyaml": ["yaml"],
3131 "pyasn1": ["pyasn1"],
3232 "daemonize": ["daemonize"],
33 "bcrypt": ["bcrypt"],
33 "bcrypt": ["bcrypt>=3.1.0"],
3434 "pillow": ["PIL"],
3535 "pydenticon": ["pydenticon"],
3636 "ujson": ["ujson"],
3737 "blist": ["blist"],
38 "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
38 "pysaml2>=3.0.0": ["saml2>=3.0.0"],
3939 "pymacaroons-pynacl": ["pymacaroons"],
4040 "msgpack-python>=0.3.0": ["msgpack"],
4141 "phonenumbers>=8.2.0": ["phonenumbers"],
0 # -*- coding: utf-8 -*-
1 # Copyright 2018 New Vector Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from synapse.http.server import JsonResource
16 from synapse.replication.http import membership, send_event
17
18
19 REPLICATION_PREFIX = "/_synapse/replication"
20
21
22 class ReplicationRestResource(JsonResource):
23 def __init__(self, hs):
24 JsonResource.__init__(self, hs, canonical_json=False)
25 self.register_servlets(hs)
26
27 def register_servlets(self, hs):
28 send_event.register_servlets(hs, self)
29 membership.register_servlets(hs, self)
0 # -*- coding: utf-8 -*-
1 # Copyright 2018 New Vector Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import logging
16 import re
17
18 from twisted.internet import defer
19
20 from synapse.api.errors import SynapseError, MatrixCodeMessageException
21 from synapse.http.servlet import RestServlet, parse_json_object_from_request
22 from synapse.types import Requester, UserID
23 from synapse.util.distributor import user_left_room, user_joined_room
24
25 logger = logging.getLogger(__name__)
26
27
28 @defer.inlineCallbacks
29 def remote_join(client, host, port, requester, remote_room_hosts,
30 room_id, user_id, content):
31 """Ask the master to do a remote join for the given user to the given room
32
33 Args:
34 client (SimpleHttpClient)
35 host (str): host of master
36 port (int): port on master listening for HTTP replication
37 requester (Requester)
38 remote_room_hosts (list[str]): Servers to try and join via
39 room_id (str)
40 user_id (str)
41 content (dict): The event content to use for the join event
42
43 Returns:
44 Deferred
45 """
46 uri = "http://%s:%s/_synapse/replication/remote_join" % (host, port)
47
48 payload = {
49 "requester": requester.serialize(),
50 "remote_room_hosts": remote_room_hosts,
51 "room_id": room_id,
52 "user_id": user_id,
53 "content": content,
54 }
55
56 try:
57 result = yield client.post_json_get_json(uri, payload)
58 except MatrixCodeMessageException as e:
59 # We convert to SynapseError as we know that it was a SynapseError
60 # on the master process that we should send to the client. (And
61 # importantly, not stack traces everywhere)
62 raise SynapseError(e.code, e.msg, e.errcode)
63 defer.returnValue(result)
64
65
66 @defer.inlineCallbacks
67 def remote_reject_invite(client, host, port, requester, remote_room_hosts,
68 room_id, user_id):
69 """Ask master to reject the invite for the user and room.
70
71 Args:
72 client (SimpleHttpClient)
73 host (str): host of master
74 port (int): port on master listening for HTTP replication
75 requester (Requester)
76 remote_room_hosts (list[str]): Servers to try and reject via
77 room_id (str)
78 user_id (str)
79
80 Returns:
81 Deferred
82 """
83 uri = "http://%s:%s/_synapse/replication/remote_reject_invite" % (host, port)
84
85 payload = {
86 "requester": requester.serialize(),
87 "remote_room_hosts": remote_room_hosts,
88 "room_id": room_id,
89 "user_id": user_id,
90 }
91
92 try:
93 result = yield client.post_json_get_json(uri, payload)
94 except MatrixCodeMessageException as e:
95 # We convert to SynapseError as we know that it was a SynapseError
96 # on the master process that we should send to the client. (And
97 # importantly, not stack traces everywhere)
98 raise SynapseError(e.code, e.msg, e.errcode)
99 defer.returnValue(result)
100
101
102 @defer.inlineCallbacks
103 def get_or_register_3pid_guest(client, host, port, requester,
104 medium, address, inviter_user_id):
105 """Ask the master to get/create a guest account for given 3PID.
106
107 Args:
108 client (SimpleHttpClient)
109 host (str): host of master
110 port (int): port on master listening for HTTP replication
111 requester (Requester)
112 medium (str)
113 address (str)
114 inviter_user_id (str): The user ID who is trying to invite the
115 3PID
116
117 Returns:
118 Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
119 3PID guest account.
120 """
121
122 uri = "http://%s:%s/_synapse/replication/get_or_register_3pid_guest" % (host, port)
123
124 payload = {
125 "requester": requester.serialize(),
126 "medium": medium,
127 "address": address,
128 "inviter_user_id": inviter_user_id,
129 }
130
131 try:
132 result = yield client.post_json_get_json(uri, payload)
133 except MatrixCodeMessageException as e:
134 # We convert to SynapseError as we know that it was a SynapseError
135 # on the master process that we should send to the client. (And
136 # importantly, not stack traces everywhere)
137 raise SynapseError(e.code, e.msg, e.errcode)
138 defer.returnValue(result)
139
140
141 @defer.inlineCallbacks
142 def notify_user_membership_change(client, host, port, user_id, room_id, change):
143 """Notify master that a user has joined or left the room
144
145 Args:
146 client (SimpleHttpClient)
147 host (str): host of master
148 port (int): port on master listening for HTTP replication.
149 user_id (str)
150 room_id (str)
151 change (str): Either "joined" or "left"
152
153 Returns:
154 Deferred
155 """
156 assert change in ("joined", "left")
157
158 uri = "http://%s:%s/_synapse/replication/user_%s_room" % (host, port, change)
159
160 payload = {
161 "user_id": user_id,
162 "room_id": room_id,
163 }
164
165 try:
166 result = yield client.post_json_get_json(uri, payload)
167 except MatrixCodeMessageException as e:
168 # We convert to SynapseError as we know that it was a SynapseError
169 # on the master process that we should send to the client. (And
170 # importantly, not stack traces everywhere)
171 raise SynapseError(e.code, e.msg, e.errcode)
172 defer.returnValue(result)
173
174
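These module-level helpers are what a worker process calls over HTTP replication instead of performing the federation work itself. Below is a hedged sketch of the worker side: the wrapper function and where the master's host/port come from are assumptions for illustration; only the remote_join helper and its arguments come from the code above.

from twisted.internet import defer

from synapse.replication.http.membership import remote_join


@defer.inlineCallbacks
def join_via_master(http_client, replication_host, replication_port,
                    requester, room_id, user_id, remote_room_hosts):
    # Defer the actual join work to the master over
    # POST /_synapse/replication/remote_join and return its result.
    result = yield remote_join(
        http_client, replication_host, replication_port,
        requester=requester,
        remote_room_hosts=remote_room_hosts,
        room_id=room_id,
        user_id=user_id,
        content={},  # no extra join event content
    )
    defer.returnValue(result)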
175 class ReplicationRemoteJoinRestServlet(RestServlet):
176 PATTERNS = [re.compile("^/_synapse/replication/remote_join$")]
177
178 def __init__(self, hs):
179 super(ReplicationRemoteJoinRestServlet, self).__init__()
180
181 self.federation_handler = hs.get_handlers().federation_handler
182 self.store = hs.get_datastore()
183 self.clock = hs.get_clock()
184
185 @defer.inlineCallbacks
186 def on_POST(self, request):
187 content = parse_json_object_from_request(request)
188
189 remote_room_hosts = content["remote_room_hosts"]
190 room_id = content["room_id"]
191 user_id = content["user_id"]
192 event_content = content["content"]
193
194 requester = Requester.deserialize(self.store, content["requester"])
195
196 if requester.user:
197 request.authenticated_entity = requester.user.to_string()
198
199 logger.info(
200 "remote_join: %s into room: %s",
201 user_id, room_id,
202 )
203
204 yield self.federation_handler.do_invite_join(
205 remote_room_hosts,
206 room_id,
207 user_id,
208 event_content,
209 )
210
211 defer.returnValue((200, {}))
212
213
214 class ReplicationRemoteRejectInviteRestServlet(RestServlet):
215 PATTERNS = [re.compile("^/_synapse/replication/remote_reject_invite$")]
216
217 def __init__(self, hs):
218 super(ReplicationRemoteRejectInviteRestServlet, self).__init__()
219
220 self.federation_handler = hs.get_handlers().federation_handler
221 self.store = hs.get_datastore()
222 self.clock = hs.get_clock()
223
224 @defer.inlineCallbacks
225 def on_POST(self, request):
226 content = parse_json_object_from_request(request)
227
228 remote_room_hosts = content["remote_room_hosts"]
229 room_id = content["room_id"]
230 user_id = content["user_id"]
231
232 requester = Requester.deserialize(self.store, content["requester"])
233
234 if requester.user:
235 request.authenticated_entity = requester.user.to_string()
236
237 logger.info(
238 "remote_reject_invite: %s out of room: %s",
239 user_id, room_id,
240 )
241
242 try:
243 event = yield self.federation_handler.do_remotely_reject_invite(
244 remote_room_hosts,
245 room_id,
246 user_id,
247 )
248 ret = event.get_pdu_json()
249 except Exception as e:
250 # if we were unable to reject the invite, just mark
251 # it as rejected on our end and plough ahead.
252 #
253 # The 'except' clause is very broad, but we need to
254 # capture everything from DNS failures upwards
255 #
256 logger.warn("Failed to reject invite: %s", e)
257
258 yield self.store.locally_reject_invite(
259 user_id, room_id
260 )
261 ret = {}
262
263 defer.returnValue((200, ret))
264
265
266 class ReplicationRegister3PIDGuestRestServlet(RestServlet):
267 PATTERNS = [re.compile("^/_synapse/replication/get_or_register_3pid_guest$")]
268
269 def __init__(self, hs):
270 super(ReplicationRegister3PIDGuestRestServlet, self).__init__()
271
272 self.registeration_handler = hs.get_handlers().registration_handler
273 self.store = hs.get_datastore()
274 self.clock = hs.get_clock()
275
276 @defer.inlineCallbacks
277 def on_POST(self, request):
278 content = parse_json_object_from_request(request)
279
280 medium = content["medium"]
281 address = content["address"]
282 inviter_user_id = content["inviter_user_id"]
283
284 requester = Requester.deserialize(self.store, content["requester"])
285
286 if requester.user:
287 request.authenticated_entity = requester.user.to_string()
288
289 logger.info("get_or_register_3pid_guest: %r", content)
290
291 ret = yield self.registeration_handler.get_or_register_3pid_guest(
292 medium, address, inviter_user_id,
293 )
294
295 defer.returnValue((200, ret))
296
297
298 class ReplicationUserJoinedLeftRoomRestServlet(RestServlet):
299 PATTERNS = [re.compile("^/_synapse/replication/user_(?P<change>joined|left)_room$")]
300
301 def __init__(self, hs):
302 super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__()
303
304 self.registeration_handler = hs.get_handlers().registration_handler
305 self.store = hs.get_datastore()
306 self.clock = hs.get_clock()
307 self.distributor = hs.get_distributor()
308
309 def on_POST(self, request, change):
310 content = parse_json_object_from_request(request)
311
312 user_id = content["user_id"]
313 room_id = content["room_id"]
314
315 logger.info("user membership change: %s in %s", user_id, room_id)
316
317 user = UserID.from_string(user_id)
318
319 if change == "joined":
320 user_joined_room(self.distributor, user, room_id)
321 elif change == "left":
322 user_left_room(self.distributor, user, room_id)
323 else:
324 raise Exception("Unrecognized change: %r", change)
325
326 return (200, {})
327
328
329 def register_servlets(hs, http_server):
330 ReplicationRemoteJoinRestServlet(hs).register(http_server)
331 ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
332 ReplicationRegister3PIDGuestRestServlet(hs).register(http_server)
333 ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
0 # -*- coding: utf-8 -*-
1 # Copyright 2018 New Vector Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from twisted.internet import defer
16
17 from synapse.api.errors import (
18 SynapseError, MatrixCodeMessageException, CodeMessageException,
19 )
20 from synapse.events import FrozenEvent
21 from synapse.events.snapshot import EventContext
22 from synapse.http.servlet import RestServlet, parse_json_object_from_request
23 from synapse.util.async import sleep
24 from synapse.util.caches.response_cache import ResponseCache
25 from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
26 from synapse.util.metrics import Measure
27 from synapse.types import Requester, UserID
28
29 import logging
30 import re
31
32 logger = logging.getLogger(__name__)
33
34
35 @defer.inlineCallbacks
36 def send_event_to_master(client, host, port, requester, event, context,
37 ratelimit, extra_users):
38 """Send event to be handled on the master
39
40 Args:
41 client (SimpleHttpClient)
42 host (str): host of master
43 port (int): port on master listening for HTTP replication
44 requester (Requester)
45 event (FrozenEvent)
46 context (EventContext)
47 ratelimit (bool)
48 extra_users (list(UserID)): Any extra users to notify about event
49 """
50 uri = "http://%s:%s/_synapse/replication/send_event/%s" % (
51 host, port, event.event_id,
52 )
53
54 payload = {
55 "event": event.get_pdu_json(),
56 "internal_metadata": event.internal_metadata.get_dict(),
57 "rejected_reason": event.rejected_reason,
58 "context": context.serialize(event),
59 "requester": requester.serialize(),
60 "ratelimit": ratelimit,
61 "extra_users": [u.to_string() for u in extra_users],
62 }
63
64 try:
65 # We keep retrying the same request for timeouts. This is so that we
66 # have a good idea that the request has either succeeded or failed on
67 # the master, and so whether we should clean up or not.
68 while True:
69 try:
70 result = yield client.put_json(uri, payload)
71 break
72 except CodeMessageException as e:
73 if e.code != 504:
74 raise
75
76 logger.warn("send_event request timed out")
77
78 # If we timed out we probably don't need to worry about backing
79 # off too much, but let's just wait a little anyway.
80 yield sleep(1)
81 except MatrixCodeMessageException as e:
82 # We convert to SynapseError as we know that it was a SynapseError
83 # on the master process that we should send to the client. (And
84 # importantly, not stack traces everywhere)
85 raise SynapseError(e.code, e.msg, e.errcode)
86 defer.returnValue(result)
87
88
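send_event_to_master is the worker-side entry point for the API described below: the worker ships a freshly created event to the master, which persists it and notifies clients, with the 504 retry loop handled inside the helper. A hedged sketch of a caller; the wrapper function and the source of the master's host/port are assumptions, only the helper's signature comes from the code above.

from twisted.internet import defer

from synapse.replication.http.send_event import send_event_to_master


@defer.inlineCallbacks
def persist_via_master(http_client, master_host, master_port,
                       requester, event, context):
    # Hand the event to the master for persistence and notification; retries
    # on 504 timeouts happen inside send_event_to_master itself.
    yield send_event_to_master(
        http_client, master_host, master_port,
        requester=requester,
        event=event,
        context=context,
        ratelimit=True,
        extra_users=[],
    )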
89 class ReplicationSendEventRestServlet(RestServlet):
90 """Handles events newly created on workers, including persisting and
91 notifying.
92
93 The API looks like:
94
95 POST /_synapse/replication/send_event/:event_id
96
97 {
98 "event": { .. serialized event .. },
99 "internal_metadata": { .. serialized internal_metadata .. },
100 "rejected_reason": .., // The event.rejected_reason field
101 "context": { .. serialized event context .. },
102 "requester": { .. serialized requester .. },
103 "ratelimit": true,
104 "extra_users": [],
105 }
106 """
107 PATTERNS = [re.compile("^/_synapse/replication/send_event/(?P<event_id>[^/]+)$")]
108
109 def __init__(self, hs):
110 super(ReplicationSendEventRestServlet, self).__init__()
111
112 self.event_creation_handler = hs.get_event_creation_handler()
113 self.store = hs.get_datastore()
114 self.clock = hs.get_clock()
115
116 # The responses are tiny, so we may as well cache them for a while
117 self.response_cache = ResponseCache(hs, timeout_ms=30 * 60 * 1000)
118
119 def on_PUT(self, request, event_id):
120 result = self.response_cache.get(event_id)
121 if not result:
122 result = self.response_cache.set(
123 event_id,
124 self._handle_request(request)
125 )
126 else:
127 logger.warn("Returning cached response")
128 return make_deferred_yieldable(result)
129
130 @preserve_fn
131 @defer.inlineCallbacks
132 def _handle_request(self, request):
133 with Measure(self.clock, "repl_send_event_parse"):
134 content = parse_json_object_from_request(request)
135
136 event_dict = content["event"]
137 internal_metadata = content["internal_metadata"]
138 rejected_reason = content["rejected_reason"]
139 event = FrozenEvent(event_dict, internal_metadata, rejected_reason)
140
141 requester = Requester.deserialize(self.store, content["requester"])
142 context = yield EventContext.deserialize(self.store, content["context"])
143
144 ratelimit = content["ratelimit"]
145 extra_users = [UserID.from_string(u) for u in content["extra_users"]]
146
147 if requester.user:
148 request.authenticated_entity = requester.user.to_string()
149
150 logger.info(
151 "Got event to send with ID: %s into room: %s",
152 event.event_id, event.room_id,
153 )
154
155 yield self.event_creation_handler.persist_and_notify_client_event(
156 requester, event, context,
157 ratelimit=ratelimit,
158 extra_users=extra_users,
159 )
160
161 defer.returnValue((200, {}))
162
163
164 def register_servlets(hs, http_server):
165 ReplicationSendEventRestServlet(hs).register(http_server)
2424
2525 class BaseSlavedStore(SQLBaseStore):
2626 def __init__(self, db_conn, hs):
27 super(BaseSlavedStore, self).__init__(hs)
27 super(BaseSlavedStore, self).__init__(db_conn, hs)
2828 if isinstance(self.database_engine, PostgresEngine):
2929 self._cache_id_gen = SlavedIdTracker(
3030 db_conn, "cache_invalidation_stream", "stream_id",
00 # -*- coding: utf-8 -*-
11 # Copyright 2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1213 # See the License for the specific language governing permissions and
1314 # limitations under the License.
1415
15 from ._base import BaseSlavedStore
16 from ._slaved_id_tracker import SlavedIdTracker
17 from synapse.storage import DataStore
18 from synapse.storage.account_data import AccountDataStore
19 from synapse.storage.tags import TagsStore
20 from synapse.util.caches.stream_change_cache import StreamChangeCache
16 from synapse.replication.slave.storage._base import BaseSlavedStore
17 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
18 from synapse.storage.account_data import AccountDataWorkerStore
19 from synapse.storage.tags import TagsWorkerStore
2120
2221
23 class SlavedAccountDataStore(BaseSlavedStore):
22 class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
2423
2524 def __init__(self, db_conn, hs):
26 super(SlavedAccountDataStore, self).__init__(db_conn, hs)
2725 self._account_data_id_gen = SlavedIdTracker(
2826 db_conn, "account_data_max_stream_id", "stream_id",
2927 )
30 self._account_data_stream_cache = StreamChangeCache(
31 "AccountDataAndTagsChangeCache",
32 self._account_data_id_gen.get_current_token(),
33 )
3428
35 get_account_data_for_user = (
36 AccountDataStore.__dict__["get_account_data_for_user"]
37 )
38
39 get_global_account_data_by_type_for_users = (
40 AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
41 )
42
43 get_global_account_data_by_type_for_user = (
44 AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
45 )
46
47 get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
48 get_tags_for_room = (
49 DataStore.get_tags_for_room.__func__
50 )
51 get_account_data_for_room = (
52 DataStore.get_account_data_for_room.__func__
53 )
54
55 get_updated_tags = DataStore.get_updated_tags.__func__
56 get_updated_account_data_for_user = (
57 DataStore.get_updated_account_data_for_user.__func__
58 )
29 super(SlavedAccountDataStore, self).__init__(db_conn, hs)
5930
6031 def get_max_account_data_stream_id(self):
6132 return self._account_data_id_gen.get_current_token()
8455 (row.data_type, row.user_id,)
8556 )
8657 self.get_account_data_for_user.invalidate((row.user_id,))
58 self.get_account_data_for_room.invalidate((row.user_id, row.room_id,))
59 self.get_account_data_for_room_and_type.invalidate(
60 (row.user_id, row.room_id, row.data_type,),
61 )
8762 self._account_data_stream_cache.entity_has_changed(
8863 row.user_id, token
8964 )
00 # -*- coding: utf-8 -*-
11 # Copyright 2015, 2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1213 # See the License for the specific language governing permissions and
1314 # limitations under the License.
1415
15 from ._base import BaseSlavedStore
16 from synapse.storage import DataStore
17 from synapse.config.appservice import load_appservices
18 from synapse.storage.appservice import _make_exclusive_regex
16 from synapse.storage.appservice import (
17 ApplicationServiceWorkerStore, ApplicationServiceTransactionWorkerStore,
18 )
1919
2020
21 class SlavedApplicationServiceStore(BaseSlavedStore):
22 def __init__(self, db_conn, hs):
23 super(SlavedApplicationServiceStore, self).__init__(db_conn, hs)
24 self.services_cache = load_appservices(
25 hs.config.server_name,
26 hs.config.app_service_config_files
27 )
28 self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
29
30 get_app_service_by_token = DataStore.get_app_service_by_token.__func__
31 get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
32 get_app_services = DataStore.get_app_services.__func__
33 get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__
34 create_appservice_txn = DataStore.create_appservice_txn.__func__
35 get_appservices_by_state = DataStore.get_appservices_by_state.__func__
36 get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__
37 _get_last_txn = DataStore._get_last_txn.__func__
38 complete_appservice_txn = DataStore.complete_appservice_txn.__func__
39 get_appservice_state = DataStore.get_appservice_state.__func__
40 set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
41 set_appservice_state = DataStore.set_appservice_state.__func__
42 get_if_app_services_interested_in_user = (
43 DataStore.get_if_app_services_interested_in_user.__func__
44 )
21 class SlavedApplicationServiceStore(ApplicationServiceTransactionWorkerStore,
22 ApplicationServiceWorkerStore):
23 pass
1313 # limitations under the License.
1414
1515 from ._base import BaseSlavedStore
16 from synapse.storage.directory import DirectoryStore
16 from synapse.storage.directory import DirectoryWorkerStore
1717
1818
19 class DirectoryStore(BaseSlavedStore):
20 get_aliases_for_room = DirectoryStore.__dict__[
21 "get_aliases_for_room"
22 ]
19 class DirectoryStore(DirectoryWorkerStore, BaseSlavedStore):
20 pass
00 # -*- coding: utf-8 -*-
11 # Copyright 2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1112 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1213 # See the License for the specific language governing permissions and
1314 # limitations under the License.
15 import logging
16
17 from synapse.api.constants import EventTypes
18 from synapse.storage.event_federation import EventFederationWorkerStore
19 from synapse.storage.event_push_actions import EventPushActionsWorkerStore
20 from synapse.storage.events_worker import EventsWorkerStore
21 from synapse.storage.roommember import RoomMemberWorkerStore
22 from synapse.storage.state import StateGroupWorkerStore
23 from synapse.storage.stream import StreamWorkerStore
24 from synapse.storage.signatures import SignatureWorkerStore
1425 from ._base import BaseSlavedStore
1526 from ._slaved_id_tracker import SlavedIdTracker
16
17 from synapse.api.constants import EventTypes
18 from synapse.storage import DataStore
19 from synapse.storage.roommember import RoomMemberStore
20 from synapse.storage.event_federation import EventFederationStore
21 from synapse.storage.event_push_actions import EventPushActionsStore
22 from synapse.storage.state import StateStore
23 from synapse.storage.stream import StreamStore
24 from synapse.util.caches.stream_change_cache import StreamChangeCache
25
26 import logging
27
2827
2928 logger = logging.getLogger(__name__)
3029
3837 # the method descriptor on the DataStore and chuck them into our class.
3938
4039
41 class SlavedEventStore(BaseSlavedStore):
40 class SlavedEventStore(EventFederationWorkerStore,
41 RoomMemberWorkerStore,
42 EventPushActionsWorkerStore,
43 StreamWorkerStore,
44 EventsWorkerStore,
45 StateGroupWorkerStore,
46 SignatureWorkerStore,
47 BaseSlavedStore):
4248
4349 def __init__(self, db_conn, hs):
44 super(SlavedEventStore, self).__init__(db_conn, hs)
4550 self._stream_id_gen = SlavedIdTracker(
4651 db_conn, "events", "stream_ordering",
4752 )
4853 self._backfill_id_gen = SlavedIdTracker(
4954 db_conn, "events", "stream_ordering", step=-1
5055 )
51 events_max = self._stream_id_gen.get_current_token()
52 event_cache_prefill, min_event_val = self._get_cache_dict(
53 db_conn, "events",
54 entity_column="room_id",
55 stream_column="stream_ordering",
56 max_value=events_max,
57 )
58 self._events_stream_cache = StreamChangeCache(
59 "EventsRoomStreamChangeCache", min_event_val,
60 prefilled_cache=event_cache_prefill,
61 )
62 self._membership_stream_cache = StreamChangeCache(
63 "MembershipStreamChangeCache", events_max,
64 )
6556
66 self.stream_ordering_month_ago = 0
67 self._stream_order_on_start = self.get_room_max_stream_ordering()
57 super(SlavedEventStore, self).__init__(db_conn, hs)
6858
6959 # Cached functions can't be accessed through a class instance so we need
7060 # to reach inside the __dict__ to extract them.
71 get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
72 get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
73 get_hosts_in_room = RoomMemberStore.__dict__["get_hosts_in_room"]
74 get_users_who_share_room_with_user = (
75 RoomMemberStore.__dict__["get_users_who_share_room_with_user"]
76 )
77 get_latest_event_ids_in_room = EventFederationStore.__dict__[
78 "get_latest_event_ids_in_room"
79 ]
80 get_invited_rooms_for_user = RoomMemberStore.__dict__[
81 "get_invited_rooms_for_user"
82 ]
83 get_unread_event_push_actions_by_room_for_user = (
84 EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"]
85 )
86 _get_unread_counts_by_receipt_txn = (
87 DataStore._get_unread_counts_by_receipt_txn.__func__
88 )
89 _get_unread_counts_by_pos_txn = (
90 DataStore._get_unread_counts_by_pos_txn.__func__
91 )
92 _get_state_group_for_events = (
93 StateStore.__dict__["_get_state_group_for_events"]
94 )
95 _get_state_group_for_event = (
96 StateStore.__dict__["_get_state_group_for_event"]
97 )
98 _get_state_groups_from_groups = (
99 StateStore.__dict__["_get_state_groups_from_groups"]
100 )
101 _get_state_groups_from_groups_txn = (
102 DataStore._get_state_groups_from_groups_txn.__func__
103 )
104 get_recent_event_ids_for_room = (
105 StreamStore.__dict__["get_recent_event_ids_for_room"]
106 )
107 get_current_state_ids = (
108 StateStore.__dict__["get_current_state_ids"]
109 )
110 get_state_group_delta = StateStore.__dict__["get_state_group_delta"]
111 _get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"]
112 has_room_changed_since = DataStore.has_room_changed_since.__func__
11361
114 get_unread_push_actions_for_user_in_range_for_http = (
115 DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
116 )
117 get_unread_push_actions_for_user_in_range_for_email = (
118 DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__
119 )
120 get_push_action_users_in_range = (
121 DataStore.get_push_action_users_in_range.__func__
122 )
123 get_event = DataStore.get_event.__func__
124 get_events = DataStore.get_events.__func__
125 get_rooms_for_user_where_membership_is = (
126 DataStore.get_rooms_for_user_where_membership_is.__func__
127 )
128 get_membership_changes_for_user = (
129 DataStore.get_membership_changes_for_user.__func__
130 )
131 get_room_events_max_id = DataStore.get_room_events_max_id.__func__
132 get_room_events_stream_for_room = (
133 DataStore.get_room_events_stream_for_room.__func__
134 )
135 get_events_around = DataStore.get_events_around.__func__
136 get_state_for_event = DataStore.get_state_for_event.__func__
137 get_state_for_events = DataStore.get_state_for_events.__func__
138 get_state_groups = DataStore.get_state_groups.__func__
139 get_state_groups_ids = DataStore.get_state_groups_ids.__func__
140 get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
141 get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
142 get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
143 get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
144 _get_joined_users_from_context = (
145 RoomMemberStore.__dict__["_get_joined_users_from_context"]
146 )
62 def get_room_max_stream_ordering(self):
63 return self._stream_id_gen.get_current_token()
14764
148 get_joined_hosts = DataStore.get_joined_hosts.__func__
149 _get_joined_hosts = RoomMemberStore.__dict__["_get_joined_hosts"]
150
151 get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
152 get_room_events_stream_for_rooms = (
153 DataStore.get_room_events_stream_for_rooms.__func__
154 )
155 is_host_joined = RoomMemberStore.__dict__["is_host_joined"]
156 get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
157
158 _set_before_and_after = staticmethod(DataStore._set_before_and_after)
159
160 _get_events = DataStore._get_events.__func__
161 _get_events_from_cache = DataStore._get_events_from_cache.__func__
162
163 _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
164 _enqueue_events = DataStore._enqueue_events.__func__
165 _do_fetch = DataStore._do_fetch.__func__
166 _fetch_event_rows = DataStore._fetch_event_rows.__func__
167 _get_event_from_row = DataStore._get_event_from_row.__func__
168 _get_rooms_for_user_where_membership_is_txn = (
169 DataStore._get_rooms_for_user_where_membership_is_txn.__func__
170 )
171 _get_state_for_groups = DataStore._get_state_for_groups.__func__
172 _get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
173 _get_events_around_txn = DataStore._get_events_around_txn.__func__
174 _get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__
175
176 get_backfill_events = DataStore.get_backfill_events.__func__
177 _get_backfill_events = DataStore._get_backfill_events.__func__
178 get_missing_events = DataStore.get_missing_events.__func__
179 _get_missing_events = DataStore._get_missing_events.__func__
180
181 get_auth_chain = DataStore.get_auth_chain.__func__
182 get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__
183 _get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__
184
185 get_room_max_stream_ordering = DataStore.get_room_max_stream_ordering.__func__
186
187 get_forward_extremeties_for_room = (
188 DataStore.get_forward_extremeties_for_room.__func__
189 )
190 _get_forward_extremeties_for_room = (
191 EventFederationStore.__dict__["_get_forward_extremeties_for_room"]
192 )
193
194 get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__
195
196 get_federation_out_pos = DataStore.get_federation_out_pos.__func__
197 update_federation_out_pos = DataStore.update_federation_out_pos.__func__
65 def get_room_min_stream_ordering(self):
66 return self._backfill_id_gen.get_current_token()
19867
19968 def stream_positions(self):
20069 result = super(SlavedEventStore, self).stream_positions()
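
The bulk of the deletions above are the methods that ``SlavedEventStore`` used to borrow from ``DataStore``, ``RoomMemberStore`` and friends; with the worker-store refactor they are simply inherited from the ``*WorkerStore`` bases instead. A minimal sketch of the old borrowing idiom versus the new mixin shape, using made-up store classes (Python 2 semantics, matching the code above)::

    class MasterStore(object):
        def has_thing_changed_since(self, key, ts):
            return ts < 100


    class BorrowingSlaveStore(object):
        # Python 2: MasterStore.has_thing_changed_since is an unbound method
        # that insists on a MasterStore instance, so the old code copied the
        # plain function out via __func__ ...
        has_thing_changed_since = MasterStore.has_thing_changed_since.__func__

        # ... and reached into __dict__ for decorated attributes (e.g. cached
        # descriptors), where ordinary attribute access does not hand back the
        # raw class attribute.  A plain function is used here purely for shape.
        get_thing = MasterStore.__dict__["has_thing_changed_since"]


    class ThingWorkerStore(object):
        # The 0.27.0 shape: shared read-only methods live in a WorkerStore
        # base that both the master DataStore and the slaved store inherit.
        def has_thing_changed_since(self, key, ts):
            return ts < 100


    class MixinSlaveStore(ThingWorkerStore):
        pass

Both slave classes end up with a working ``has_thing_changed_since``, but the mixin version avoids the long hand-maintained list that this release deletes.
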
0 # -*- coding: utf-8 -*-
1 # Copyright 2018 New Vector Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from synapse.replication.slave.storage._base import BaseSlavedStore
16 from synapse.storage.profile import ProfileWorkerStore
17
18
19 class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore):
20 pass
00 # -*- coding: utf-8 -*-
11 # Copyright 2015, 2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1415
1516 from .events import SlavedEventStore
1617 from ._slaved_id_tracker import SlavedIdTracker
17 from synapse.storage import DataStore
18 from synapse.storage.push_rule import PushRuleStore
19 from synapse.util.caches.stream_change_cache import StreamChangeCache
18 from synapse.storage.push_rule import PushRulesWorkerStore
2019
2120
22 class SlavedPushRuleStore(SlavedEventStore):
21 class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
2322 def __init__(self, db_conn, hs):
24 super(SlavedPushRuleStore, self).__init__(db_conn, hs)
2523 self._push_rules_stream_id_gen = SlavedIdTracker(
2624 db_conn, "push_rules_stream", "stream_id",
2725 )
28 self.push_rules_stream_cache = StreamChangeCache(
29 "PushRulesStreamChangeCache",
30 self._push_rules_stream_id_gen.get_current_token(),
31 )
32
33 get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
34 get_push_rules_enabled_for_user = (
35 PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
36 )
37 have_push_rules_changed_for_user = (
38 DataStore.have_push_rules_changed_for_user.__func__
39 )
26 super(SlavedPushRuleStore, self).__init__(db_conn, hs)
4027
4128 def get_push_rules_stream_token(self):
4229 return (
4330 self._push_rules_stream_id_gen.get_current_token(),
4431 self._stream_id_gen.get_current_token(),
4532 )
33
34 def get_max_push_rules_stream_id(self):
35 return self._push_rules_stream_id_gen.get_current_token()
4636
4737 def stream_positions(self):
4838 result = super(SlavedPushRuleStore, self).stream_positions()
00 # -*- coding: utf-8 -*-
11 # Copyright 2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1516 from ._base import BaseSlavedStore
1617 from ._slaved_id_tracker import SlavedIdTracker
1718
18 from synapse.storage import DataStore
19 from synapse.storage.pusher import PusherWorkerStore
1920
2021
21 class SlavedPusherStore(BaseSlavedStore):
22 class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
2223
2324 def __init__(self, db_conn, hs):
2425 super(SlavedPusherStore, self).__init__(db_conn, hs)
2627 db_conn, "pushers", "id",
2728 extra_tables=[("deleted_pushers", "stream_id")],
2829 )
29
30 get_all_pushers = DataStore.get_all_pushers.__func__
31 get_pushers_by = DataStore.get_pushers_by.__func__
32 get_pushers_by_app_id_and_pushkey = (
33 DataStore.get_pushers_by_app_id_and_pushkey.__func__
34 )
35 _decode_pushers_rows = DataStore._decode_pushers_rows.__func__
3630
3731 def stream_positions(self):
3832 result = super(SlavedPusherStore, self).stream_positions()
00 # -*- coding: utf-8 -*-
11 # Copyright 2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1516 from ._base import BaseSlavedStore
1617 from ._slaved_id_tracker import SlavedIdTracker
1718
18 from synapse.storage import DataStore
19 from synapse.storage.receipts import ReceiptsStore
20 from synapse.util.caches.stream_change_cache import StreamChangeCache
19 from synapse.storage.receipts import ReceiptsWorkerStore
2120
2221 # So, um, we want to borrow a load of functions intended for reading from
2322 # a DataStore, but we don't want to take functions that either write to the
2827 # the method descriptor on the DataStore and chuck them into our class.
2928
3029
31 class SlavedReceiptsStore(BaseSlavedStore):
30 class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
3231
3332 def __init__(self, db_conn, hs):
34 super(SlavedReceiptsStore, self).__init__(db_conn, hs)
35
33 # We instantiate this first as the ReceiptsWorkerStore constructor
34 # needs to be able to call get_max_receipt_stream_id
3635 self._receipts_id_gen = SlavedIdTracker(
3736 db_conn, "receipts_linearized", "stream_id"
3837 )
3938
40 self._receipts_stream_cache = StreamChangeCache(
41 "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
42 )
39 super(SlavedReceiptsStore, self).__init__(db_conn, hs)
4340
44 get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
45 get_linearized_receipts_for_room = (
46 ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
47 )
48 _get_linearized_receipts_for_rooms = (
49 ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
50 )
51 get_last_receipt_event_id_for_user = (
52 ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
53 )
54
55 get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
56 get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
57
58 get_linearized_receipts_for_rooms = (
59 DataStore.get_linearized_receipts_for_rooms.__func__
60 )
41 def get_max_receipt_stream_id(self):
42 return self._receipts_id_gen.get_current_token()
6143
6244 def stream_positions(self):
6345 result = super(SlavedReceiptsStore, self).stream_positions()
7052 self.get_last_receipt_event_id_for_user.invalidate(
7153 (user_id, room_id, receipt_type)
7254 )
55 self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
56 self.get_receipts_for_room.invalidate((room_id, receipt_type))
7357
7458 def process_replication_rows(self, stream_name, token, rows):
7559 if stream_name == "receipts":
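
As the new comment above notes, the slaved receipts store has to create its ``SlavedIdTracker`` before calling the superclass constructor, because ``ReceiptsWorkerStore.__init__`` already needs ``get_max_receipt_stream_id``. A stripped-down sketch of that ordering constraint, with stand-in classes rather than the real Synapse ones::

    class IdTrackerSketch(object):
        """Stand-in for SlavedIdTracker: just remembers a current token."""
        def __init__(self, current):
            self._current = current

        def get_current_token(self):
            return self._current


    class ReceiptsWorkerStoreSketch(object):
        def __init__(self):
            # The worker store's constructor immediately needs the stream
            # position, e.g. to prime a stream change cache.
            self.cache_position = self.get_max_receipt_stream_id()

        def get_max_receipt_stream_id(self):
            raise NotImplementedError


    class SlavedReceiptsStoreSketch(ReceiptsWorkerStoreSketch):
        def __init__(self):
            # Created *before* the superclass constructor runs, mirroring the
            # comment in the real SlavedReceiptsStore above.
            self._receipts_id_gen = IdTrackerSketch(current=42)
            super(SlavedReceiptsStoreSketch, self).__init__()

        def get_max_receipt_stream_id(self):
            return self._receipts_id_gen.get_current_token()
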
1313 # limitations under the License.
1414
1515 from ._base import BaseSlavedStore
16 from synapse.storage import DataStore
17 from synapse.storage.registration import RegistrationStore
16 from synapse.storage.registration import RegistrationWorkerStore
1817
1918
20 class SlavedRegistrationStore(BaseSlavedStore):
21 def __init__(self, db_conn, hs):
22 super(SlavedRegistrationStore, self).__init__(db_conn, hs)
23
24 # TODO: use the cached version and invalidate deleted tokens
25 get_user_by_access_token = RegistrationStore.__dict__[
26 "get_user_by_access_token"
27 ]
28
29 _query_for_auth = DataStore._query_for_auth.__func__
30 get_user_by_id = RegistrationStore.__dict__[
31 "get_user_by_id"
32 ]
19 class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore):
20 pass
1313 # limitations under the License.
1414
1515 from ._base import BaseSlavedStore
16 from synapse.storage import DataStore
17 from synapse.storage.room import RoomStore
16 from synapse.storage.room import RoomWorkerStore
1817 from ._slaved_id_tracker import SlavedIdTracker
1918
2019
21 class RoomStore(BaseSlavedStore):
20 class RoomStore(RoomWorkerStore, BaseSlavedStore):
2221 def __init__(self, db_conn, hs):
2322 super(RoomStore, self).__init__(db_conn, hs)
2423 self._public_room_id_gen = SlavedIdTracker(
2524 db_conn, "public_room_list_stream", "stream_id"
2625 )
2726
28 get_public_room_ids = DataStore.get_public_room_ids.__func__
29 get_current_public_room_stream_id = (
30 DataStore.get_current_public_room_stream_id.__func__
31 )
32 get_public_room_ids_at_stream_id = (
33 RoomStore.__dict__["get_public_room_ids_at_stream_id"]
34 )
35 get_public_room_ids_at_stream_id_txn = (
36 DataStore.get_public_room_ids_at_stream_id_txn.__func__
37 )
38 get_published_at_stream_id_txn = (
39 DataStore.get_published_at_stream_id_txn.__func__
40 )
41 get_public_room_changes = DataStore.get_public_room_changes.__func__
27 def get_current_public_room_stream_id(self):
28 return self._public_room_id_gen.get_current_token()
4229
4330 def stream_positions(self):
4431 result = super(RoomStore, self).stream_positions()
1818 """
1919
2020 import logging
21 import ujson as json
21 import simplejson
2222
2323
2424 logger = logging.getLogger(__name__)
9999 return cls(
100100 stream_name,
101101 None if token == "batch" else int(token),
102 json.loads(row_json)
102 simplejson.loads(row_json)
103103 )
104104
105105 def to_line(self):
106106 return " ".join((
107107 self.stream_name,
108108 str(self.token) if self.token is not None else "batch",
109 json.dumps(self.row),
109 simplejson.dumps(self.row, namedtuple_as_object=False),
110110 ))
111111
112112
297297 def from_line(cls, line):
298298 cache_func, keys_json = line.split(" ", 1)
299299
300 return cls(cache_func, json.loads(keys_json))
301
302 def to_line(self):
303 return " ".join((self.cache_func, json.dumps(self.keys)))
300 return cls(cache_func, simplejson.loads(keys_json))
301
302 def to_line(self):
303 return " ".join((
304 self.cache_func, simplejson.dumps(self.keys, namedtuple_as_object=False)
305 ))
304306
305307
306308 class UserIpCommand(Command):
324326 def from_line(cls, line):
325327 user_id, jsn = line.split(" ", 1)
326328
327 access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
329 access_token, ip, user_agent, device_id, last_seen = simplejson.loads(jsn)
328330
329331 return cls(
330332 user_id, access_token, ip, user_agent, device_id, last_seen
331333 )
332334
333335 def to_line(self):
334 return self.user_id + " " + json.dumps((
336 return self.user_id + " " + simplejson.dumps((
335337 self.access_token, self.ip, self.user_agent, self.device_id,
336338 self.last_seen,
337339 ))
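
The replication commands above switch from ``ujson`` to ``simplejson`` and pass ``namedtuple_as_object=False`` when encoding rows. By default ``simplejson`` serialises namedtuples as JSON objects, which would come back from ``loads`` as dicts and could no longer be splatted into ``ROW_TYPE(*cmd.row)``; the flag keeps them as positional lists. A small illustration with a hypothetical row type (the real stream row types live elsewhere in ``synapse.replication.tcp``)::

    from collections import namedtuple

    import simplejson

    # Hypothetical row type, not one of Synapse's real ROW_TYPEs.
    EventStreamRow = namedtuple("EventStreamRow", ("event_id", "room_id"))

    row = EventStreamRow("$ev:example.com", "!room:example.com")

    # simplejson's default (namedtuple_as_object=True) encodes the row as a
    # JSON object, which loads() would hand back as a dict.
    print(simplejson.dumps(row))
    # {"event_id": "$ev:example.com", "room_id": "!room:example.com"}

    # With namedtuple_as_object=False the row stays a positional list and
    # round-trips cleanly through ROW_TYPE(*...):
    wire = simplejson.dumps(row, namedtuple_as_object=False)
    print(EventStreamRow(*simplejson.loads(wire)))
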
516516 self.send_error("Wrong remote")
517517
518518 def on_RDATA(self, cmd):
519 stream_name = cmd.stream_name
520 inbound_rdata_count.inc(stream_name)
521
519522 try:
520 row = STREAMS_MAP[cmd.stream_name].ROW_TYPE(*cmd.row)
523 row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row)
521524 except Exception:
522525 logger.exception(
523526 "[%s] Failed to parse RDATA: %r %r",
524 self.id(), cmd.stream_name, cmd.row
527 self.id(), stream_name, cmd.row
525528 )
526529 raise
527530
528531 if cmd.token is None:
529532 # I.e. this is part of a batch of updates for this stream. Batch
530533 # until we get an update for the stream with a non None token
531 self.pending_batches.setdefault(cmd.stream_name, []).append(row)
534 self.pending_batches.setdefault(stream_name, []).append(row)
532535 else:
533536 # Check if this is the last of a batch of updates
534 rows = self.pending_batches.pop(cmd.stream_name, [])
537 rows = self.pending_batches.pop(stream_name, [])
535538 rows.append(row)
536539
537 self.handler.on_rdata(cmd.stream_name, cmd.token, rows)
540 self.handler.on_rdata(stream_name, cmd.token, rows)
538541
539542 def on_POSITION(self, cmd):
540543 self.handler.on_position(cmd.stream_name, cmd.token)
643646 },
644647 labels=["command", "name", "conn_id"],
645648 )
649
650 # number of updates received for each RDATA stream
651 inbound_rdata_count = metrics.register_counter(
652 "inbound_rdata_count",
653 labels=["stream_name"],
654 )
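
Beyond the new per-stream ``inbound_rdata_count`` counter, the batching behaviour of ``on_RDATA`` is unchanged: rows whose token arrives as the literal ``batch`` (parsed to ``None``) are buffered per stream until a row with a real token closes the batch. A toy model of that accumulation, not the real protocol class::

    pending_batches = {}

    def on_rdata_row(stream_name, token, row, deliver):
        """Buffer rows for `stream_name` until one arrives with a real token,
        then hand the whole batch to `deliver` (mirrors on_RDATA above)."""
        if token is None:
            # Token was the literal "batch" on the wire: more rows to come.
            pending_batches.setdefault(stream_name, []).append(row)
        else:
            rows = pending_batches.pop(stream_name, [])
            rows.append(row)
            deliver(stream_name, token, rows)

    def deliver(stream_name, token, rows):
        print("%s up to %s: %s" % (stream_name, token, rows))

    on_rdata_row("events", None, "row-1", deliver)
    on_rdata_row("events", None, "row-2", deliver)
    on_rdata_row("events", 105, "row-3", deliver)
    # events up to 105: ['row-1', 'row-2', 'row-3']
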
161161 )
162162 try:
163163 updates, current_token = yield stream.get_updates()
164 except:
164 except Exception:
165165 logger.info("Failed to handle stream %s", stream.NAME)
166166 raise
167167
215215 self.federation_sender.federation_ack(token)
216216
217217 @measure_func("repl.on_user_sync")
218 @defer.inlineCallbacks
218219 def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
219220 """A client has started/stopped syncing on a worker.
220221 """
221222 user_sync_counter.inc()
222 self.presence_handler.update_external_syncs_row(
223 yield self.presence_handler.update_external_syncs_row(
223224 conn_id, user_id, is_syncing, last_sync_ms,
224225 )
225226
243244 getattr(self.store, cache_func).invalidate(tuple(keys))
244245
245246 @measure_func("repl.on_user_ip")
247 @defer.inlineCallbacks
246248 def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
247249 """The client saw a user request
248250 """
249251 user_ip_cache_counter.inc()
250 self.store.insert_client_ip(
252 yield self.store.insert_client_ip(
251253 user_id, access_token, ip, user_agent, device_id, last_seen,
252254 )
253255
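
``on_user_sync`` and ``on_user_ip`` are now ``inlineCallbacks`` methods that yield the presence/storage calls instead of dropping the returned deferreds, so completion can be awaited and failures propagate to the caller. A minimal Twisted sketch of the difference, using a toy write function rather than the Synapse handlers::

    from twisted.internet import defer

    def write_row():
        # Stand-in for a storage call such as store.insert_client_ip();
        # it may fail asynchronously.
        return defer.fail(RuntimeError("db went away"))

    def fire_and_forget():
        # Old shape: the deferred from write_row() is dropped, so callers
        # can neither wait for it nor see the failure.
        write_row()

    @defer.inlineCallbacks
    def yielding():
        # New shape: the caller's deferred only fires once the write has
        # finished, and any failure propagates to it.
        yield write_row()

    def report(failure):
        print("caught: %r" % (failure.value,))

    yielding().addErrback(report)
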
00 # -*- coding: utf-8 -*-
11 # Copyright 2014-2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1516 from twisted.internet import defer
1617
1718 from synapse.api.constants import Membership
18 from synapse.api.errors import AuthError, SynapseError
19 from synapse.api.errors import AuthError, SynapseError, Codes, NotFoundError
1920 from synapse.types import UserID, create_requester
2021 from synapse.http.servlet import parse_json_object_from_request
2122
112113
113114 class PurgeHistoryRestServlet(ClientV1RestServlet):
114115 PATTERNS = client_path_patterns(
115 "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
116 "/admin/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
116117 )
117118
118119 def __init__(self, hs):
120 """
121
122 Args:
123 hs (synapse.server.HomeServer)
124 """
119125 super(PurgeHistoryRestServlet, self).__init__(hs)
120126 self.handlers = hs.get_handlers()
127 self.store = hs.get_datastore()
121128
122129 @defer.inlineCallbacks
123130 def on_POST(self, request, room_id, event_id):
127134 if not is_admin:
128135 raise AuthError(403, "You are not a server admin")
129136
130 yield self.handlers.message_handler.purge_history(room_id, event_id)
131
132 defer.returnValue((200, {}))
137 body = parse_json_object_from_request(request, allow_empty_body=True)
138
139 delete_local_events = bool(body.get("delete_local_events", False))
140
141 # establish the topological ordering we should keep events from. The
142 # user can provide an event_id in the URL or the request body, or can
143 # provide a timestamp in the request body.
144 if event_id is None:
145 event_id = body.get('purge_up_to_event_id')
146
147 if event_id is not None:
148 event = yield self.store.get_event(event_id)
149
150 if event.room_id != room_id:
151 raise SynapseError(400, "Event is for wrong room.")
152
153 depth = event.depth
154 logger.info(
155 "[purge] purging up to depth %i (event_id %s)",
156 depth, event_id,
157 )
158 elif 'purge_up_to_ts' in body:
159 ts = body['purge_up_to_ts']
160 if not isinstance(ts, int):
161 raise SynapseError(
162 400, "purge_up_to_ts must be an int",
163 errcode=Codes.BAD_JSON,
164 )
165
166 stream_ordering = (
167 yield self.store.find_first_stream_ordering_after_ts(ts)
168 )
169
170 (_, depth, _) = (
171 yield self.store.get_room_event_after_stream_ordering(
172 room_id, stream_ordering,
173 )
174 )
175 logger.info(
176 "[purge] purging up to depth %i (received_ts %i => "
177 "stream_ordering %i)",
178 depth, ts, stream_ordering,
179 )
180 else:
181 raise SynapseError(
182 400,
183 "must specify purge_up_to_event_id or purge_up_to_ts",
184 errcode=Codes.BAD_JSON,
185 )
186
187 purge_id = yield self.handlers.message_handler.start_purge_history(
188 room_id, depth,
189 delete_local_events=delete_local_events,
190 )
191
192 defer.returnValue((200, {
193 "purge_id": purge_id,
194 }))
195
196
197 class PurgeHistoryStatusRestServlet(ClientV1RestServlet):
198 PATTERNS = client_path_patterns(
199 "/admin/purge_history_status/(?P<purge_id>[^/]+)"
200 )
201
202 def __init__(self, hs):
203 """
204
205 Args:
206 hs (synapse.server.HomeServer)
207 """
208 super(PurgeHistoryStatusRestServlet, self).__init__(hs)
209 self.handlers = hs.get_handlers()
210
211 @defer.inlineCallbacks
212 def on_GET(self, request, purge_id):
213 requester = yield self.auth.get_user_by_req(request)
214 is_admin = yield self.auth.is_server_admin(requester.user)
215
216 if not is_admin:
217 raise AuthError(403, "You are not a server admin")
218
219 purge_status = self.handlers.message_handler.get_purge_status(purge_id)
220 if purge_status is None:
221 raise NotFoundError("purge id '%s' not found" % purge_id)
222
223 defer.returnValue((200, purge_status.asdict()))
133224
134225
135226 class DeactivateAccountRestServlet(ClientV1RestServlet):
136227 PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
137228
138229 def __init__(self, hs):
139 self.store = hs.get_datastore()
140230 super(DeactivateAccountRestServlet, self).__init__(hs)
231 self._deactivate_account_handler = hs.get_deactivate_account_handler()
141232
142233 @defer.inlineCallbacks
143234 def on_POST(self, request, target_user_id):
148239 if not is_admin:
149240 raise AuthError(403, "You are not a server admin")
150241
151 # FIXME: Theoretically there is a race here wherein user resets password
152 # using threepid.
153 yield self.store.user_delete_access_tokens(target_user_id)
154 yield self.store.user_delete_threepids(target_user_id)
155 yield self.store.user_set_password_hash(target_user_id, None)
156
242 yield self._deactivate_account_handler.deactivate_account(target_user_id)
157243 defer.returnValue((200, {}))
158244
159245
175261 self.store = hs.get_datastore()
176262 self.handlers = hs.get_handlers()
177263 self.state = hs.get_state_handler()
264 self.event_creation_handler = hs.get_event_creation_handler()
265 self.room_member_handler = hs.get_room_member_handler()
178266
179267 @defer.inlineCallbacks
180268 def on_POST(self, request, room_id):
207295 )
208296 new_room_id = info["room_id"]
209297
210 msg_handler = self.handlers.message_handler
211 yield msg_handler.create_and_send_nonmember_event(
298 yield self.event_creation_handler.create_and_send_nonmember_event(
212299 room_creator_requester,
213300 {
214301 "type": "m.room.message",
234321 logger.info("Kicking %r from %r...", user_id, room_id)
235322
236323 target_requester = create_requester(user_id)
237 yield self.handlers.room_member_handler.update_membership(
324 yield self.room_member_handler.update_membership(
238325 requester=target_requester,
239326 target=target_requester.user,
240327 room_id=room_id,
243330 ratelimit=False
244331 )
245332
246 yield self.handlers.room_member_handler.forget(target_requester.user, room_id)
247
248 yield self.handlers.room_member_handler.update_membership(
333 yield self.room_member_handler.forget(target_requester.user, room_id)
334
335 yield self.room_member_handler.update_membership(
249336 requester=target_requester,
250337 target=target_requester.user,
251338 room_id=new_room_id,
291378 )
292379
293380 defer.returnValue((200, {"num_quarantined": num_quarantined}))
381
382
383 class ListMediaInRoom(ClientV1RestServlet):
384 """Lists all of the media in a given room.
385 """
386 PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media")
387
388 def __init__(self, hs):
389 super(ListMediaInRoom, self).__init__(hs)
390 self.store = hs.get_datastore()
391
392 @defer.inlineCallbacks
393 def on_GET(self, request, room_id):
394 requester = yield self.auth.get_user_by_req(request)
395 is_admin = yield self.auth.is_server_admin(requester.user)
396 if not is_admin:
397 raise AuthError(403, "You are not a server admin")
398
399 local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
400
401 defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs}))
294402
295403
296404 class ResetPasswordRestServlet(ClientV1RestServlet):
313421 super(ResetPasswordRestServlet, self).__init__(hs)
314422 self.hs = hs
315423 self.auth = hs.get_auth()
316 self.auth_handler = hs.get_auth_handler()
424 self._set_password_handler = hs.get_set_password_handler()
317425
318426 @defer.inlineCallbacks
319427 def on_POST(self, request, target_user_id):
334442
335443 logger.info("new_password: %r", new_password)
336444
337 yield self.auth_handler.set_password(
445 yield self._set_password_handler.set_password(
338446 target_user_id, new_password, requester
339447 )
340448 defer.returnValue((200, {}))
483591 def register_servlets(hs, http_server):
484592 WhoisRestServlet(hs).register(http_server)
485593 PurgeMediaCacheRestServlet(hs).register(http_server)
594 PurgeHistoryStatusRestServlet(hs).register(http_server)
486595 DeactivateAccountRestServlet(hs).register(http_server)
487596 PurgeHistoryRestServlet(hs).register(http_server)
488597 UsersRestServlet(hs).register(http_server)
491600 SearchUsersRestServlet(hs).register(http_server)
492601 ShutdownRoomRestServlet(hs).register(http_server)
493602 QuarantineMediaInRoom(hs).register(http_server)
603 ListMediaInRoom(hs).register(http_server)
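
The purge-history servlet now takes the event ID either from the URL or as ``purge_up_to_event_id`` in the body, alternatively accepts a ``purge_up_to_ts`` timestamp, honours a ``delete_local_events`` flag, and returns a ``purge_id`` that the new status servlet can be polled with. A rough usage sketch; the ``/_matrix/client/r0`` prefix, hostname and token below are assumptions, and ``docs/admin_api/purge_history_api.rst`` remains the authoritative description::

    import requests

    BASE = "https://matrix.example.com/_matrix/client/r0"   # assumed prefix
    TOKEN = {"access_token": "<admin access token>"}
    ROOM_ID = "!room:example.com"

    # Purge everything the server received before the given timestamp (ms).
    resp = requests.post(
        "%s/admin/purge_history/%s" % (BASE, ROOM_ID),
        params=TOKEN,
        json={
            "purge_up_to_ts": 1514764800000,
            "delete_local_events": False,
        },
    )
    purge_id = resp.json()["purge_id"]

    # Poll the status endpoint added in this release.
    status = requests.get(
        "%s/admin/purge_history_status/%s" % (BASE, purge_id),
        params=TOKEN,
    ).json()
    print(status)
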
9292 )
9393 except SynapseError as e:
9494 raise e
95 except:
95 except Exception:
9696 logger.exception("Failed to create association")
9797 raise
9898 except AuthError:
8484
8585 class LoginRestServlet(ClientV1RestServlet):
8686 PATTERNS = client_path_patterns("/login$")
87 PASS_TYPE = "m.login.password"
8887 SAML2_TYPE = "m.login.saml2"
8988 CAS_TYPE = "m.login.cas"
9089 TOKEN_TYPE = "m.login.token"
9392 def __init__(self, hs):
9493 super(LoginRestServlet, self).__init__(hs)
9594 self.idp_redirect_url = hs.config.saml2_idp_redirect_url
96 self.password_enabled = hs.config.password_enabled
9795 self.saml2_enabled = hs.config.saml2_enabled
9896 self.jwt_enabled = hs.config.jwt_enabled
9997 self.jwt_secret = hs.config.jwt_secret
120118 # fall back to the fallback API if they don't understand one of the
121119 # login flow types returned.
122120 flows.append({"type": LoginRestServlet.TOKEN_TYPE})
123 if self.password_enabled:
124 flows.append({"type": LoginRestServlet.PASS_TYPE})
121
122 flows.extend((
123 {"type": t} for t in self.auth_handler.get_supported_login_types()
124 ))
125125
126126 return (200, {"flows": flows})
127127
132132 def on_POST(self, request):
133133 login_submission = parse_json_object_from_request(request)
134134 try:
135 if login_submission["type"] == LoginRestServlet.PASS_TYPE:
136 if not self.password_enabled:
137 raise SynapseError(400, "Password login has been disabled.")
138
139 result = yield self.do_password_login(login_submission)
140 defer.returnValue(result)
141 elif self.saml2_enabled and (login_submission["type"] ==
142 LoginRestServlet.SAML2_TYPE):
135 if self.saml2_enabled and (login_submission["type"] ==
136 LoginRestServlet.SAML2_TYPE):
143137 relay_state = ""
144138 if "relay_state" in login_submission:
145139 relay_state = "&RelayState=" + urllib.quote(
156150 result = yield self.do_token_login(login_submission)
157151 defer.returnValue(result)
158152 else:
159 raise SynapseError(400, "Bad login type.")
153 result = yield self._do_other_login(login_submission)
154 defer.returnValue(result)
160155 except KeyError:
161156 raise SynapseError(400, "Missing JSON keys.")
162157
163158 @defer.inlineCallbacks
164 def do_password_login(self, login_submission):
165 if "password" not in login_submission:
166 raise SynapseError(400, "Missing parameter: password")
167
159 def _do_other_login(self, login_submission):
160 """Handle non-token/saml/jwt logins
161
162 Args:
163 login_submission:
164
165 Returns:
166 (int, object): HTTP code/response
167 """
168 # Log the request we got, but only certain fields to minimise the chance of
169 # logging someone's password (even if they accidentally put it in the wrong
170 # field)
171 logger.info(
172 "Got login request with identifier: %r, medium: %r, address: %r, user: %r",
173 login_submission.get('identifier'),
174 login_submission.get('medium'),
175 login_submission.get('address'),
176 login_submission.get('user'),
177 )
168178 login_submission_legacy_convert(login_submission)
169179
170180 if "identifier" not in login_submission:
180190
181191 # convert threepid identifiers to user IDs
182192 if identifier["type"] == "m.id.thirdparty":
183 if 'medium' not in identifier or 'address' not in identifier:
193 address = identifier.get('address')
194 medium = identifier.get('medium')
195
196 if medium is None or address is None:
184197 raise SynapseError(400, "Invalid thirdparty identifier")
185198
186 address = identifier['address']
187 if identifier['medium'] == 'email':
199 if medium == 'email':
188200 # For emails, transform the address to lowercase.
189201 # We store all email addreses as lowercase in the DB.
190202 # (See add_threepid in synapse/handlers/auth.py)
191203 address = address.lower()
192204 user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
193 identifier['medium'], address
205 medium, address,
194206 )
195207 if not user_id:
208 logger.warn(
209 "unknown 3pid identifier medium %s, address %r",
210 medium, address,
211 )
196212 raise LoginError(403, "", errcode=Codes.FORBIDDEN)
197213
198214 identifier = {
207223 if "user" not in identifier:
208224 raise SynapseError(400, "User identifier is missing 'user' key")
209225
210 user_id = identifier["user"]
211
212 if not user_id.startswith('@'):
213 user_id = UserID.create(
214 user_id, self.hs.hostname
215 ).to_string()
216
217226 auth_handler = self.auth_handler
218 user_id = yield auth_handler.validate_password_login(
219 user_id=user_id,
220 password=login_submission["password"],
221 )
222 device_id = yield self._register_device(user_id, login_submission)
227 canonical_user_id, callback = yield auth_handler.validate_login(
228 identifier["user"],
229 login_submission,
230 )
231
232 device_id = yield self._register_device(
233 canonical_user_id, login_submission,
234 )
223235 access_token = yield auth_handler.get_access_token_for_user_id(
224 user_id, device_id,
225 login_submission.get("initial_device_display_name"),
226 )
236 canonical_user_id, device_id,
237 )
238
227239 result = {
228 "user_id": user_id, # may have changed
240 "user_id": canonical_user_id,
229241 "access_token": access_token,
230242 "home_server": self.hs.hostname,
231243 "device_id": device_id,
232244 }
245
246 if callback is not None:
247 yield callback(result)
233248
234249 defer.returnValue((200, result))
235250
243258 device_id = yield self._register_device(user_id, login_submission)
244259 access_token = yield auth_handler.get_access_token_for_user_id(
245260 user_id, device_id,
246 login_submission.get("initial_device_display_name"),
247261 )
248262 result = {
249263 "user_id": user_id, # may have changed
277291 if user is None:
278292 raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
279293
280 user_id = UserID.create(user, self.hs.hostname).to_string()
294 user_id = UserID(user, self.hs.hostname).to_string()
281295 auth_handler = self.auth_handler
282296 registered_user_id = yield auth_handler.check_user_exists(user_id)
283297 if registered_user_id:
286300 )
287301 access_token = yield auth_handler.get_access_token_for_user_id(
288302 registered_user_id, device_id,
289 login_submission.get("initial_device_display_name"),
290303 )
291304
292305 result = {
443456 if required_value != actual_value:
444457 raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
445458
446 user_id = UserID.create(user, self.hs.hostname).to_string()
459 user_id = UserID(user, self.hs.hostname).to_string()
447460 auth_handler = self.auth_handler
448461 registered_user_id = yield auth_handler.check_user_exists(user_id)
449462 if not registered_user_id:
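
Password-style logins are now routed through ``_do_other_login`` and keyed off the ``identifier`` object (legacy bodies are still rewritten by ``login_submission_legacy_convert``). Illustrative request bodies for the two identifier forms handled above; all values are made up::

    # m.id.user: a localpart or full user ID; validate_login canonicalises it
    # and returns the canonical user ID used in the response.
    password_login = {
        "type": "m.login.password",
        "identifier": {"type": "m.id.user", "user": "alice"},
        "password": "s3cret",
        "initial_device_display_name": "my phone",
    }

    # m.id.thirdparty: looked up via get_user_id_by_threepid; email addresses
    # are lower-cased before the lookup.
    threepid_login = {
        "type": "m.login.password",
        "identifier": {
            "type": "m.id.thirdparty",
            "medium": "email",
            "address": "Alice@Example.com",
        },
        "password": "s3cret",
    }
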
1515 from twisted.internet import defer
1616
1717 from synapse.api.auth import get_access_token_from_request
18 from synapse.api.errors import AuthError
1819
1920 from .base import ClientV1RestServlet, client_path_patterns
2021
2930
3031 def __init__(self, hs):
3132 super(LogoutRestServlet, self).__init__(hs)
32 self.store = hs.get_datastore()
33 self._auth = hs.get_auth()
34 self._auth_handler = hs.get_auth_handler()
35 self._device_handler = hs.get_device_handler()
3336
3437 def on_OPTIONS(self, request):
3538 return (200, {})
3639
3740 @defer.inlineCallbacks
3841 def on_POST(self, request):
39 access_token = get_access_token_from_request(request)
40 yield self.store.delete_access_token(access_token)
42 try:
43 requester = yield self.auth.get_user_by_req(request)
44 except AuthError:
45 # this implies the access token has already been deleted.
46 pass
47 else:
48 if requester.device_id is None:
49                 # the access token wasn't associated with a device.
50 # Just delete the access token
51 access_token = get_access_token_from_request(request)
52 yield self._auth_handler.delete_access_token(access_token)
53 else:
54 yield self._device_handler.delete_device(
55 requester.user.to_string(), requester.device_id)
56
4157 defer.returnValue((200, {}))
4258
4359
4662
4763 def __init__(self, hs):
4864 super(LogoutAllRestServlet, self).__init__(hs)
49 self.store = hs.get_datastore()
5065 self.auth = hs.get_auth()
66 self._auth_handler = hs.get_auth_handler()
67 self._device_handler = hs.get_device_handler()
5168
5269 def on_OPTIONS(self, request):
5370 return (200, {})
5673 def on_POST(self, request):
5774 requester = yield self.auth.get_user_by_req(request)
5875 user_id = requester.user.to_string()
59 yield self.store.user_delete_access_tokens(user_id)
76
77 # first delete all of the user's devices
78 yield self._device_handler.delete_all_devices_for_user(user_id)
79
80 # .. and then delete any access tokens which weren't associated with
81 # devices.
82 yield self._auth_handler.delete_access_tokens_for_user(user_id)
6083 defer.returnValue((200, {}))
6184
6285
7777 raise KeyError()
7878 except SynapseError as e:
7979 raise e
80 except:
80 except Exception:
8181 raise SynapseError(400, "Unable to parse state")
8282
8383 yield self.presence_handler.set_state(user, state)
5151
5252 try:
5353 new_name = content["displayname"]
54 except:
54 except Exception:
5555 defer.returnValue((400, "Unable to parse name"))
5656
5757 yield self.profile_handler.set_displayname(
9393 content = parse_json_object_from_request(request)
9494 try:
9595 new_name = content["avatar_url"]
96 except:
96 except Exception:
9797 defer.returnValue((400, "Unable to parse name"))
9898
9999 yield self.profile_handler.set_avatar_url(
6969 self.handlers = hs.get_handlers()
7070
7171 def on_GET(self, request):
72
73 require_email = 'email' in self.hs.config.registrations_require_3pid
74 require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
75
76 flows = []
7277 if self.hs.config.enable_registration_captcha:
73 return (
74 200,
75 {"flows": [
78 # only support the email-only flow if we don't require MSISDN 3PIDs
79 if not require_msisdn:
80 flows.extend([
7681 {
7782 "type": LoginType.RECAPTCHA,
7883 "stages": [
8186 LoginType.PASSWORD
8287 ]
8388 },
89 ])
90 # only support 3PIDless registration if no 3PIDs are required
91 if not require_email and not require_msisdn:
92 flows.extend([
8493 {
8594 "type": LoginType.RECAPTCHA,
8695 "stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
8796 }
88 ]}
89 )
97 ])
9098 else:
91 return (
92 200,
93 {"flows": [
99 # only support the email-only flow if we don't require MSISDN 3PIDs
100 if require_email or not require_msisdn:
101 flows.extend([
94102 {
95103 "type": LoginType.EMAIL_IDENTITY,
96104 "stages": [
97105 LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
98106 ]
99 },
107 }
108 ])
109 # only support 3PIDless registration if no 3PIDs are required
110 if not require_email and not require_msisdn:
111 flows.extend([
100112 {
101113 "type": LoginType.PASSWORD
102114 }
103 ]}
104 )
115 ])
116 return (200, {"flows": flows})
105117
106118 @defer.inlineCallbacks
107119 def on_POST(self, request):
358370 if compare_digest(want_mac, got_mac):
359371 handler = self.handlers.registration_handler
360372 user_id, token = yield handler.register(
361 localpart=user,
373 localpart=user.lower(),
362374 password=password,
363375 admin=bool(admin),
364376 )
00 # -*- coding: utf-8 -*-
11 # Copyright 2014-2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
2829
2930 import logging
3031 import urllib
31 import ujson as json
32 import simplejson as json
3233
3334 logger = logging.getLogger(__name__)
3435
8182 def __init__(self, hs):
8283 super(RoomStateEventRestServlet, self).__init__(hs)
8384 self.handlers = hs.get_handlers()
85 self.event_creation_hander = hs.get_event_creation_handler()
86 self.room_member_handler = hs.get_room_member_handler()
8487
8588 def register(self, http_server):
8689 # /room/$roomid/state/$eventtype
153156
154157 if event_type == EventTypes.Member:
155158 membership = content.get("membership", None)
156 event = yield self.handlers.room_member_handler.update_membership(
159 event = yield self.room_member_handler.update_membership(
157160 requester,
158161 target=UserID.from_string(state_key),
159162 room_id=room_id,
161164 content=content,
162165 )
163166 else:
164 msg_handler = self.handlers.message_handler
165 event, context = yield msg_handler.create_event(
167 event, context = yield self.event_creation_hander.create_event(
166168 requester,
167169 event_dict,
168170 token_id=requester.access_token_id,
169171 txn_id=txn_id,
170172 )
171173
172 yield msg_handler.send_nonmember_event(requester, event, context)
174 yield self.event_creation_hander.send_nonmember_event(
175 requester, event, context,
176 )
173177
174178 ret = {}
175179 if event:
182186
183187 def __init__(self, hs):
184188 super(RoomSendEventRestServlet, self).__init__(hs)
185 self.handlers = hs.get_handlers()
189 self.event_creation_hander = hs.get_event_creation_handler()
186190
187191 def register(self, http_server):
188192 # /rooms/$roomid/send/$event_type[/$txn_id]
194198 requester = yield self.auth.get_user_by_req(request, allow_guest=True)
195199 content = parse_json_object_from_request(request)
196200
197 msg_handler = self.handlers.message_handler
198 event = yield msg_handler.create_and_send_nonmember_event(
201 event_dict = {
202 "type": event_type,
203 "content": content,
204 "room_id": room_id,
205 "sender": requester.user.to_string(),
206 }
207
208 if 'ts' in request.args and requester.app_service:
209 event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
210
211 event = yield self.event_creation_hander.create_and_send_nonmember_event(
199212 requester,
200 {
201 "type": event_type,
202 "content": content,
203 "room_id": room_id,
204 "sender": requester.user.to_string(),
205 },
213 event_dict,
206214 txn_id=txn_id,
207215 )
208216
221229 class JoinRoomAliasServlet(ClientV1RestServlet):
222230 def __init__(self, hs):
223231 super(JoinRoomAliasServlet, self).__init__(hs)
224 self.handlers = hs.get_handlers()
232 self.room_member_handler = hs.get_room_member_handler()
225233
226234 def register(self, http_server):
227235 # /join/$room_identifier[/$txn_id]
237245
238246 try:
239247 content = parse_json_object_from_request(request)
240 except:
248 except Exception:
241249 # Turns out we used to ignore the body entirely, and some clients
242250 # cheekily send invalid bodies.
243251 content = {}
246254 room_id = room_identifier
247255 try:
248256 remote_room_hosts = request.args["server_name"]
249 except:
257 except Exception:
250258 remote_room_hosts = None
251259 elif RoomAlias.is_valid(room_identifier):
252 handler = self.handlers.room_member_handler
260 handler = self.room_member_handler
253261 room_alias = RoomAlias.from_string(room_identifier)
254262 room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
255263 room_id = room_id.to_string()
258266 room_identifier,
259267 ))
260268
261 yield self.handlers.room_member_handler.update_membership(
269 yield self.room_member_handler.update_membership(
262270 requester=requester,
263271 target=requester.user,
264272 room_id=room_id,
486494 defer.returnValue((200, content))
487495
488496
489 class RoomEventContext(ClientV1RestServlet):
497 class RoomEventServlet(ClientV1RestServlet):
498 PATTERNS = client_path_patterns(
499 "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$"
500 )
501
502 def __init__(self, hs):
503 super(RoomEventServlet, self).__init__(hs)
504 self.clock = hs.get_clock()
505 self.event_handler = hs.get_event_handler()
506
507 @defer.inlineCallbacks
508 def on_GET(self, request, room_id, event_id):
509 requester = yield self.auth.get_user_by_req(request)
510 event = yield self.event_handler.get_event(requester.user, event_id)
511
512 time_now = self.clock.time_msec()
513 if event:
514 defer.returnValue((200, serialize_event(event, time_now)))
515 else:
516 defer.returnValue((404, "Event not found."))
517
518
519 class RoomEventContextServlet(ClientV1RestServlet):
490520 PATTERNS = client_path_patterns(
491521 "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
492522 )
493523
494524 def __init__(self, hs):
495 super(RoomEventContext, self).__init__(hs)
525 super(RoomEventContextServlet, self).__init__(hs)
496526 self.clock = hs.get_clock()
497527 self.handlers = hs.get_handlers()
498528
532562 class RoomForgetRestServlet(ClientV1RestServlet):
533563 def __init__(self, hs):
534564 super(RoomForgetRestServlet, self).__init__(hs)
535 self.handlers = hs.get_handlers()
565 self.room_member_handler = hs.get_room_member_handler()
536566
537567 def register(self, http_server):
538568 PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
545575 allow_guest=False,
546576 )
547577
548 yield self.handlers.room_member_handler.forget(
578 yield self.room_member_handler.forget(
549579 user=requester.user,
550580 room_id=room_id,
551581 )
563593
564594 def __init__(self, hs):
565595 super(RoomMembershipRestServlet, self).__init__(hs)
566 self.handlers = hs.get_handlers()
596 self.room_member_handler = hs.get_room_member_handler()
567597
568598 def register(self, http_server):
569599 # /rooms/$roomid/[invite|join|leave]
570600 PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
571 "(?P<membership_action>join|invite|leave|ban|unban|kick|forget)")
601 "(?P<membership_action>join|invite|leave|ban|unban|kick)")
572602 register_txn_path(self, PATTERNS, http_server)
573603
574604 @defer.inlineCallbacks
586616
587617 try:
588618 content = parse_json_object_from_request(request)
589 except:
619 except Exception:
590620 # Turns out we used to ignore the body entirely, and some clients
591621 # cheekily send invalid bodies.
592622 content = {}
593623
594624 if membership_action == "invite" and self._has_3pid_invite_keys(content):
595 yield self.handlers.room_member_handler.do_3pid_invite(
625 yield self.room_member_handler.do_3pid_invite(
596626 room_id,
597627 requester.user,
598628 content["medium"],
614644 if 'reason' in content and membership_action in ['kick', 'ban']:
615645 event_content = {'reason': content['reason']}
616646
617 yield self.handlers.room_member_handler.update_membership(
647 yield self.room_member_handler.update_membership(
618648 requester=requester,
619649 target=target,
620650 room_id=room_id,
642672 def __init__(self, hs):
643673 super(RoomRedactEventRestServlet, self).__init__(hs)
644674 self.handlers = hs.get_handlers()
675 self.event_creation_handler = hs.get_event_creation_handler()
645676
646677 def register(self, http_server):
647678 PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
652683 requester = yield self.auth.get_user_by_req(request)
653684 content = parse_json_object_from_request(request)
654685
655 msg_handler = self.handlers.message_handler
656 event = yield msg_handler.create_and_send_nonmember_event(
686 event = yield self.event_creation_handler.create_and_send_nonmember_event(
657687 requester,
658688 {
659689 "type": EventTypes.Redaction,
802832 RoomTypingRestServlet(hs).register(http_server)
803833 SearchRestServlet(hs).register(http_server)
804834 JoinedRoomsRestServlet(hs).register(http_server)
805 RoomEventContext(hs).register(http_server)
835 RoomEventServlet(hs).register(http_server)
836 RoomEventContextServlet(hs).register(http_server)
1414
1515 """This module contains base REST classes for constructing client v1 servlets.
1616 """
17
18 from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
17 import logging
1918 import re
2019
21 import logging
20 from twisted.internet import defer
2221
22 from synapse.api.errors import InteractiveAuthIncompleteError
23 from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
2324
2425 logger = logging.getLogger(__name__)
2526
5657 filter_json['room']['timeline']["limit"] = min(
5758 filter_json['room']['timeline']['limit'],
5859 filter_timeline_limit)
60
61
62 def interactive_auth_handler(orig):
63 """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
64
65     Takes an on_POST method which returns a deferred (errcode, body) response
66     and adds exception handling to turn an InteractiveAuthIncompleteError into
67 a 401 response.
68
69 Normal usage is:
70
71 @interactive_auth_handler
72 @defer.inlineCallbacks
73 def on_POST(self, request):
74 # ...
75 yield self.auth_handler.check_auth
76 """
77 def wrapped(*args, **kwargs):
78 res = defer.maybeDeferred(orig, *args, **kwargs)
79 res.addErrback(_catch_incomplete_interactive_auth)
80 return res
81 return wrapped
82
83
84 def _catch_incomplete_interactive_auth(f):
85 """helper for interactive_auth_handler
86
87 Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
88
89 Args:
90 f (failure.Failure):
91 """
92 f.trap(InteractiveAuthIncompleteError)
93 return 401, f.value.result
1212 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
15 import logging
1516
1617 from twisted.internet import defer
1718
19 from synapse.api.auth import has_access_token
1820 from synapse.api.constants import LoginType
19 from synapse.api.errors import LoginError, SynapseError, Codes
21 from synapse.api.errors import Codes, SynapseError
2022 from synapse.http.servlet import (
21 RestServlet, parse_json_object_from_request, assert_params_in_request
23 RestServlet, assert_params_in_request,
24 parse_json_object_from_request,
2225 )
2326 from synapse.util.async import run_on_reactor
2427 from synapse.util.msisdn import phone_number_to_msisdn
25
26 from ._base import client_v2_patterns
27
28 import logging
29
28 from synapse.util.threepids import check_3pid_allowed
29 from ._base import client_v2_patterns, interactive_auth_handler
3030
3131 logger = logging.getLogger(__name__)
3232
4646 assert_params_in_request(body, [
4747 'id_server', 'client_secret', 'email', 'send_attempt'
4848 ])
49
50 if not check_3pid_allowed(self.hs, "email", body['email']):
51 raise SynapseError(
52 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
53 )
4954
5055 existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
5156 'email', body['email']
7883
7984 msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
8085
86 if not check_3pid_allowed(self.hs, "msisdn", msisdn):
87 raise SynapseError(
88 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
89 )
90
8191 existingUid = yield self.datastore.get_user_id_by_threepid(
8292 'msisdn', msisdn
8393 )
98108 self.auth = hs.get_auth()
99109 self.auth_handler = hs.get_auth_handler()
100110 self.datastore = self.hs.get_datastore()
101
102 @defer.inlineCallbacks
103 def on_POST(self, request):
104 yield run_on_reactor()
105
106 body = parse_json_object_from_request(request)
107
108 authed, result, params, _ = yield self.auth_handler.check_auth([
109 [LoginType.PASSWORD],
110 [LoginType.EMAIL_IDENTITY],
111 [LoginType.MSISDN],
112 ], body, self.hs.get_ip_from_request(request))
113
114 if not authed:
115 defer.returnValue((401, result))
116
117 user_id = None
118 requester = None
119
120 if LoginType.PASSWORD in result:
121 # if using password, they should also be logged in
111 self._set_password_handler = hs.get_set_password_handler()
112
113 @interactive_auth_handler
114 @defer.inlineCallbacks
115 def on_POST(self, request):
116 body = parse_json_object_from_request(request)
117
118 # there are two possibilities here. Either the user does not have an
119 # access token, and needs to do a password reset; or they have one and
120 # need to validate their identity.
121 #
122 # In the first case, we offer a couple of means of identifying
123 # themselves (email and msisdn, though it's unclear if msisdn actually
124 # works).
125 #
126 # In the second case, we require a password to confirm their identity.
127
128 if has_access_token(request):
122129 requester = yield self.auth.get_user_by_req(request)
130 params = yield self.auth_handler.validate_user_via_ui_auth(
131 requester, body, self.hs.get_ip_from_request(request),
132 )
123133 user_id = requester.user.to_string()
124 if user_id != result[LoginType.PASSWORD]:
125 raise LoginError(400, "", Codes.UNKNOWN)
126 elif LoginType.EMAIL_IDENTITY in result:
127 threepid = result[LoginType.EMAIL_IDENTITY]
128 if 'medium' not in threepid or 'address' not in threepid:
129 raise SynapseError(500, "Malformed threepid")
130 if threepid['medium'] == 'email':
131 # For emails, transform the address to lowercase.
132 # We store all email addreses as lowercase in the DB.
133 # (See add_threepid in synapse/handlers/auth.py)
134 threepid['address'] = threepid['address'].lower()
135 # if using email, we must know about the email they're authing with!
136 threepid_user_id = yield self.datastore.get_user_id_by_threepid(
137 threepid['medium'], threepid['address']
138 )
139 if not threepid_user_id:
140 raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
141 user_id = threepid_user_id
142134 else:
143 logger.error("Auth succeeded but no known type!", result.keys())
144 raise SynapseError(500, "", Codes.UNKNOWN)
135 requester = None
136 result, params, _ = yield self.auth_handler.check_auth(
137 [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]],
138 body, self.hs.get_ip_from_request(request),
139 )
140
141 if LoginType.EMAIL_IDENTITY in result:
142 threepid = result[LoginType.EMAIL_IDENTITY]
143 if 'medium' not in threepid or 'address' not in threepid:
144 raise SynapseError(500, "Malformed threepid")
145 if threepid['medium'] == 'email':
146 # For emails, transform the address to lowercase.
147                 # We store all email addresses as lowercase in the DB.
148 # (See add_threepid in synapse/handlers/auth.py)
149 threepid['address'] = threepid['address'].lower()
150 # if using email, we must know about the email they're authing with!
151 threepid_user_id = yield self.datastore.get_user_id_by_threepid(
152 threepid['medium'], threepid['address']
153 )
154 if not threepid_user_id:
155 raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
156 user_id = threepid_user_id
157 else:
158             logger.error("Auth succeeded but no known type! %r", result.keys())
159 raise SynapseError(500, "", Codes.UNKNOWN)
145160
146161 if 'new_password' not in params:
147162 raise SynapseError(400, "", Codes.MISSING_PARAM)
148163 new_password = params['new_password']
149164
150 yield self.auth_handler.set_password(
165 yield self._set_password_handler.set_password(
151166 user_id, new_password, requester
152167 )
153168
161176 PATTERNS = client_v2_patterns("/account/deactivate$")
162177
163178 def __init__(self, hs):
164 self.hs = hs
165 self.store = hs.get_datastore()
179 super(DeactivateAccountRestServlet, self).__init__()
180 self.hs = hs
166181 self.auth = hs.get_auth()
167182 self.auth_handler = hs.get_auth_handler()
168 super(DeactivateAccountRestServlet, self).__init__()
169
170 @defer.inlineCallbacks
171 def on_POST(self, request):
172 body = parse_json_object_from_request(request)
173
174 authed, result, params, _ = yield self.auth_handler.check_auth([
175 [LoginType.PASSWORD],
176 ], body, self.hs.get_ip_from_request(request))
177
178 if not authed:
179 defer.returnValue((401, result))
180
181 user_id = None
182 requester = None
183
184 if LoginType.PASSWORD in result:
185 # if using password, they should also be logged in
186 requester = yield self.auth.get_user_by_req(request)
187 user_id = requester.user.to_string()
188 if user_id != result[LoginType.PASSWORD]:
189 raise LoginError(400, "", Codes.UNKNOWN)
190 else:
191 logger.error("Auth succeeded but no known type!", result.keys())
192 raise SynapseError(500, "", Codes.UNKNOWN)
193
194 # FIXME: Theoretically there is a race here wherein user resets password
195 # using threepid.
196 yield self.store.user_delete_access_tokens(user_id)
197 yield self.store.user_delete_threepids(user_id)
198 yield self.store.user_set_password_hash(user_id, None)
199
183 self._deactivate_account_handler = hs.get_deactivate_account_handler()
184
185 @interactive_auth_handler
186 @defer.inlineCallbacks
187 def on_POST(self, request):
188 body = parse_json_object_from_request(request)
189
190 requester = yield self.auth.get_user_by_req(request)
191
192         # allow ASes to deactivate their own users
193 if requester.app_service:
194 yield self._deactivate_account_handler.deactivate_account(
195 requester.user.to_string()
196 )
197 defer.returnValue((200, {}))
198
199 yield self.auth_handler.validate_user_via_ui_auth(
200 requester, body, self.hs.get_ip_from_request(request),
201 )
202 yield self._deactivate_account_handler.deactivate_account(
203 requester.user.to_string(),
204 )
200205 defer.returnValue((200, {}))
201206
202207
221226
222227 if absent:
223228 raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
229
230 if not check_3pid_allowed(self.hs, "email", body['email']):
231 raise SynapseError(
232 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
233 )
224234
225235 existingUid = yield self.datastore.get_user_id_by_threepid(
226236 'email', body['email']
260270
261271 msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
262272
273 if not check_3pid_allowed(self.hs, "msisdn", msisdn):
274 raise SynapseError(
275 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
276 )
277
263278 existingUid = yield self.datastore.get_user_id_by_threepid(
264279 'msisdn', msisdn
265280 )
370385 )
371386
372387 defer.returnValue((200, {}))
388
389
390 class WhoamiRestServlet(RestServlet):
391 PATTERNS = client_v2_patterns("/account/whoami$")
392
393 def __init__(self, hs):
394 super(WhoamiRestServlet, self).__init__()
395 self.auth = hs.get_auth()
396
397 @defer.inlineCallbacks
398 def on_GET(self, request):
399 requester = yield self.auth.get_user_by_req(request)
400
401 defer.returnValue((200, {'user_id': requester.user.to_string()}))
373402
374403
375404 def register_servlets(hs, http_server):
381410 MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
382411 ThreepidRestServlet(hs).register(http_server)
383412 ThreepidDeleteRestServlet(hs).register(http_server)
413 WhoamiRestServlet(hs).register(http_server)
1616
1717 from twisted.internet import defer
1818
19 from synapse.api import constants, errors
19 from synapse.api import errors
2020 from synapse.http import servlet
21 from ._base import client_v2_patterns
21 from ._base import client_v2_patterns, interactive_auth_handler
2222
2323 logger = logging.getLogger(__name__)
2424
2525
2626 class DevicesRestServlet(servlet.RestServlet):
27 PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
27 PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
2828
2929 def __init__(self, hs):
3030 """
5050 API for bulk deletion of devices. Accepts a JSON object with a devices
5151 key which lists the device_ids to delete. Requires user interactive auth.
5252 """
53 PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False)
53 PATTERNS = client_v2_patterns("/delete_devices", v2_alpha=False)
5454
5555 def __init__(self, hs):
5656 super(DeleteDevicesRestServlet, self).__init__()
5959 self.device_handler = hs.get_device_handler()
6060 self.auth_handler = hs.get_auth_handler()
6161
62 @interactive_auth_handler
6263 @defer.inlineCallbacks
6364 def on_POST(self, request):
65 requester = yield self.auth.get_user_by_req(request)
66
6467 try:
6568 body = servlet.parse_json_object_from_request(request)
6669 except errors.SynapseError as e:
7679 400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
7780 )
7881
79 authed, result, params, _ = yield self.auth_handler.check_auth([
80 [constants.LoginType.PASSWORD],
81 ], body, self.hs.get_ip_from_request(request))
82 yield self.auth_handler.validate_user_via_ui_auth(
83 requester, body, self.hs.get_ip_from_request(request),
84 )
8285
83 if not authed:
84 defer.returnValue((401, result))
85
86 requester = yield self.auth.get_user_by_req(request)
8786 yield self.device_handler.delete_devices(
8887 requester.user.to_string(),
8988 body['devices'],
9291
9392
9493 class DeviceRestServlet(servlet.RestServlet):
95 PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
96 releases=[], v2_alpha=False)
94 PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
9795
9896 def __init__(self, hs):
9997 """
115113 )
116114 defer.returnValue((200, device))
117115
116 @interactive_auth_handler
118117 @defer.inlineCallbacks
119118 def on_DELETE(self, request, device_id):
119 requester = yield self.auth.get_user_by_req(request)
120
120121 try:
121122 body = servlet.parse_json_object_from_request(request)
122123
128129 else:
129130 raise
130131
131 authed, result, params, _ = yield self.auth_handler.check_auth([
132 [constants.LoginType.PASSWORD],
133 ], body, self.hs.get_ip_from_request(request))
132 yield self.auth_handler.validate_user_via_ui_auth(
133 requester, body, self.hs.get_ip_from_request(request),
134 )
134135
135 if not authed:
136 defer.returnValue((401, result))
137
138 requester = yield self.auth.get_user_by_req(request)
139136 yield self.device_handler.delete_device(
140 requester.user.to_string(),
141 device_id,
137 requester.user.to_string(), device_id,
142138 )
143139 defer.returnValue((200, {}))
144140
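# An illustrative (not from the diff) second-round request body for the
# device-deletion endpoints above: after validate_user_via_ui_auth() replies
# with a 401 carrying a session and the allowed flows, the client retries the
# same request with an "auth" object attached. All values here are assumptions.
UIA_RETRY_BODY = {
    "devices": ["QBUAZIFURK"],          # only for /delete_devices
    "auth": {
        "type": "m.login.password",     # the password UI-auth stage
        "session": "<session id from the 401 response>",
        "user": "@alice:example.com",
        "password": "correct horse battery staple",
    },
}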
4949
5050 try:
5151 filter_id = int(filter_id)
52 except:
52 except Exception:
5353 raise SynapseError(400, "Invalid filter_id")
5454
5555 try:
3737
3838 @defer.inlineCallbacks
3939 def on_GET(self, request, group_id):
40 requester = yield self.auth.get_user_by_req(request)
41 user_id = requester.user.to_string()
42
43 group_description = yield self.groups_handler.get_group_profile(group_id, user_id)
40 requester = yield self.auth.get_user_by_req(request, allow_guest=True)
41 requester_user_id = requester.user.to_string()
42
43 group_description = yield self.groups_handler.get_group_profile(
44 group_id,
45 requester_user_id,
46 )
4447
4548 defer.returnValue((200, group_description))
4649
4750 @defer.inlineCallbacks
4851 def on_POST(self, request, group_id):
4952 requester = yield self.auth.get_user_by_req(request)
50 user_id = requester.user.to_string()
53 requester_user_id = requester.user.to_string()
5154
5255 content = parse_json_object_from_request(request)
5356 yield self.groups_handler.update_group_profile(
54 group_id, user_id, content,
57 group_id, requester_user_id, content,
5558 )
5659
5760 defer.returnValue((200, {}))
7073
7174 @defer.inlineCallbacks
7275 def on_GET(self, request, group_id):
73 requester = yield self.auth.get_user_by_req(request)
74 user_id = requester.user.to_string()
75
76 get_group_summary = yield self.groups_handler.get_group_summary(group_id, user_id)
76 requester = yield self.auth.get_user_by_req(request, allow_guest=True)
77 requester_user_id = requester.user.to_string()
78
79 get_group_summary = yield self.groups_handler.get_group_summary(
80 group_id,
81 requester_user_id,
82 )
7783
7884 defer.returnValue((200, get_group_summary))
7985
100106 @defer.inlineCallbacks
101107 def on_PUT(self, request, group_id, category_id, room_id):
102108 requester = yield self.auth.get_user_by_req(request)
103 user_id = requester.user.to_string()
109 requester_user_id = requester.user.to_string()
104110
105111 content = parse_json_object_from_request(request)
106112 resp = yield self.groups_handler.update_group_summary_room(
107 group_id, user_id,
113 group_id, requester_user_id,
108114 room_id=room_id,
109115 category_id=category_id,
110116 content=content,
115121 @defer.inlineCallbacks
116122 def on_DELETE(self, request, group_id, category_id, room_id):
117123 requester = yield self.auth.get_user_by_req(request)
118 user_id = requester.user.to_string()
124 requester_user_id = requester.user.to_string()
119125
120126 resp = yield self.groups_handler.delete_group_summary_room(
121 group_id, user_id,
127 group_id, requester_user_id,
122128 room_id=room_id,
123129 category_id=category_id,
124130 )
141147
142148 @defer.inlineCallbacks
143149 def on_GET(self, request, group_id, category_id):
144 requester = yield self.auth.get_user_by_req(request)
145 user_id = requester.user.to_string()
150 requester = yield self.auth.get_user_by_req(request, allow_guest=True)
151 requester_user_id = requester.user.to_string()
146152
147153 category = yield self.groups_handler.get_group_category(
148 group_id, user_id,
154 group_id, requester_user_id,
149155 category_id=category_id,
150156 )
151157
154160 @defer.inlineCallbacks
155161 def on_PUT(self, request, group_id, category_id):
156162 requester = yield self.auth.get_user_by_req(request)
157 user_id = requester.user.to_string()
163 requester_user_id = requester.user.to_string()
158164
159165 content = parse_json_object_from_request(request)
160166 resp = yield self.groups_handler.update_group_category(
161 group_id, user_id,
167 group_id, requester_user_id,
162168 category_id=category_id,
163169 content=content,
164170 )
168174 @defer.inlineCallbacks
169175 def on_DELETE(self, request, group_id, category_id):
170176 requester = yield self.auth.get_user_by_req(request)
171 user_id = requester.user.to_string()
177 requester_user_id = requester.user.to_string()
172178
173179 resp = yield self.groups_handler.delete_group_category(
174 group_id, user_id,
180 group_id, requester_user_id,
175181 category_id=category_id,
176182 )
177183
193199
194200 @defer.inlineCallbacks
195201 def on_GET(self, request, group_id):
196 requester = yield self.auth.get_user_by_req(request)
197 user_id = requester.user.to_string()
202 requester = yield self.auth.get_user_by_req(request, allow_guest=True)
203 requester_user_id = requester.user.to_string()
198204
199205 category = yield self.groups_handler.get_group_categories(
200 group_id, user_id,
206 group_id, requester_user_id,
201207 )
202208
203209 defer.returnValue((200, category))
218224
219225 @defer.inlineCallbacks
220226 def on_GET(self, request, group_id, role_id):
221 requester = yield self.auth.get_user_by_req(request)
222 user_id = requester.user.to_string()
227 requester = yield self.auth.get_user_by_req(request, allow_guest=True)
228 requester_user_id = requester.user.to_string()
223229
224230 category = yield self.groups_handler.get_group_role(
225 group_id, user_id,
231 group_id, requester_user_id,
226232 role_id=role_id,
227233 )
228234
231237 @defer.inlineCallbacks
232238 def on_PUT(self, request, group_id, role_id):
233239 requester = yield self.auth.get_user_by_req(request)
234 user_id = requester.user.to_string()
240 requester_user_id = requester.user.to_string()
235241
236242 content = parse_json_object_from_request(request)
237243 resp = yield self.groups_handler.update_group_role(
238 group_id, user_id,
244 group_id, requester_user_id,
239245 role_id=role_id,
240246 content=content,
241247 )
245251 @defer.inlineCallbacks
246252 def on_DELETE(self, request, group_id, role_id):
247253 requester = yield self.auth.get_user_by_req(request)
248 user_id = requester.user.to_string()
254 requester_user_id = requester.user.to_string()
249255
250256 resp = yield self.groups_handler.delete_group_role(
251 group_id, user_id,
257 group_id, requester_user_id,
252258 role_id=role_id,
253259 )
254260
270276
271277 @defer.inlineCallbacks
272278 def on_GET(self, request, group_id):
273 requester = yield self.auth.get_user_by_req(request)
274 user_id = requester.user.to_string()
279 requester = yield self.auth.get_user_by_req(request, allow_guest=True)
280 requester_user_id = requester.user.to_string()
275281
276282 category = yield self.groups_handler.get_group_roles(
277 group_id, user_id,
283 group_id, requester_user_id,
278284 )
279285
280286 defer.returnValue((200, category))
341347
342348 @defer.inlineCallbacks
343349 def on_GET(self, request, group_id):
344 requester = yield self.auth.get_user_by_req(request)
345 user_id = requester.user.to_string()
346
347 result = yield self.groups_handler.get_rooms_in_group(group_id, user_id)
350 requester = yield self.auth.get_user_by_req(request, allow_guest=True)
351 requester_user_id = requester.user.to_string()
352
353 result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id)
348354
349355 defer.returnValue((200, result))
350356
362368
363369 @defer.inlineCallbacks
364370 def on_GET(self, request, group_id):
365 requester = yield self.auth.get_user_by_req(request)
366 user_id = requester.user.to_string()
367
368 result = yield self.groups_handler.get_users_in_group(group_id, user_id)
371 requester = yield self.auth.get_user_by_req(request, allow_guest=True)
372 requester_user_id = requester.user.to_string()
373
374 result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id)
369375
370376 defer.returnValue((200, result))
371377
384390 @defer.inlineCallbacks
385391 def on_GET(self, request, group_id):
386392 requester = yield self.auth.get_user_by_req(request)
387 user_id = requester.user.to_string()
388
389 result = yield self.groups_handler.get_invited_users_in_group(group_id, user_id)
393 requester_user_id = requester.user.to_string()
394
395 result = yield self.groups_handler.get_invited_users_in_group(
396 group_id,
397 requester_user_id,
398 )
390399
391400 defer.returnValue((200, result))
392401
406415 @defer.inlineCallbacks
407416 def on_POST(self, request):
408417 requester = yield self.auth.get_user_by_req(request)
409 user_id = requester.user.to_string()
418 requester_user_id = requester.user.to_string()
410419
411420 # TODO: Create group on remote server
412421 content = parse_json_object_from_request(request)
413422 localpart = content.pop("localpart")
414 group_id = GroupID.create(localpart, self.server_name).to_string()
415
416 result = yield self.groups_handler.create_group(group_id, user_id, content)
423 group_id = GroupID(localpart, self.server_name).to_string()
424
425 result = yield self.groups_handler.create_group(
426 group_id,
427 requester_user_id,
428 content,
429 )
417430
418431 defer.returnValue((200, result))
419432
434447 @defer.inlineCallbacks
435448 def on_PUT(self, request, group_id, room_id):
436449 requester = yield self.auth.get_user_by_req(request)
437 user_id = requester.user.to_string()
450 requester_user_id = requester.user.to_string()
438451
439452 content = parse_json_object_from_request(request)
440453 result = yield self.groups_handler.add_room_to_group(
441 group_id, user_id, room_id, content,
454 group_id, requester_user_id, room_id, content,
442455 )
443456
444457 defer.returnValue((200, result))
446459 @defer.inlineCallbacks
447460 def on_DELETE(self, request, group_id, room_id):
448461 requester = yield self.auth.get_user_by_req(request)
449 user_id = requester.user.to_string()
462 requester_user_id = requester.user.to_string()
450463
451464 result = yield self.groups_handler.remove_room_from_group(
452 group_id, user_id, room_id,
465 group_id, requester_user_id, room_id,
466 )
467
468 defer.returnValue((200, result))
469
470
471 class GroupAdminRoomsConfigServlet(RestServlet):
472 """Update the config of a room in a group
473 """
474 PATTERNS = client_v2_patterns(
475 "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
476 "/config/(?P<config_key>[^/]*)$"
477 )
478
479 def __init__(self, hs):
480 super(GroupAdminRoomsConfigServlet, self).__init__()
481 self.auth = hs.get_auth()
482 self.clock = hs.get_clock()
483 self.groups_handler = hs.get_groups_local_handler()
484
485 @defer.inlineCallbacks
486 def on_PUT(self, request, group_id, room_id, config_key):
487 requester = yield self.auth.get_user_by_req(request)
488 requester_user_id = requester.user.to_string()
489
490 content = parse_json_object_from_request(request)
491 result = yield self.groups_handler.update_room_in_group(
492 group_id, requester_user_id, room_id, config_key, content,
453493 )
454494
455495 defer.returnValue((200, result))
631671
632672 @defer.inlineCallbacks
633673 def on_GET(self, request, user_id):
634 yield self.auth.get_user_by_req(request)
674 yield self.auth.get_user_by_req(request, allow_guest=True)
635675
636676 result = yield self.groups_handler.get_publicised_groups_for_user(
637677 user_id
656696
657697 @defer.inlineCallbacks
658698 def on_POST(self, request):
659 yield self.auth.get_user_by_req(request)
699 yield self.auth.get_user_by_req(request, allow_guest=True)
660700
661701 content = parse_json_object_from_request(request)
662702 user_ids = content["user_ids"]
683723
684724 @defer.inlineCallbacks
685725 def on_GET(self, request):
686 requester = yield self.auth.get_user_by_req(request)
687 user_id = requester.user.to_string()
688
689 result = yield self.groups_handler.get_joined_groups(user_id)
726 requester = yield self.auth.get_user_by_req(request, allow_guest=True)
727 requester_user_id = requester.user.to_string()
728
729 result = yield self.groups_handler.get_joined_groups(requester_user_id)
690730
691731 defer.returnValue((200, result))
692732
699739 GroupRoomServlet(hs).register(http_server)
700740 GroupCreateServlet(hs).register(http_server)
701741 GroupAdminRoomsServlet(hs).register(http_server)
742 GroupAdminRoomsConfigServlet(hs).register(http_server)
702743 GroupAdminUsersInviteServlet(hs).register(http_server)
703744 GroupAdminUsersKickServlet(hs).register(http_server)
704745 GroupSelfLeaveServlet(hs).register(http_server)
5252 },
5353 }
5454 """
55 PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
56 releases=())
55 PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
5756
5857 def __init__(self, hs):
5958 """
127126 } } } } } }
128127 """
129128
130 PATTERNS = client_v2_patterns(
131 "/keys/query$",
132 releases=()
133 )
129 PATTERNS = client_v2_patterns("/keys/query$")
134130
135131 def __init__(self, hs):
136132 """
159155 200 OK
160156 { "changed": ["@foo:example.com"] }
161157 """
162 PATTERNS = client_v2_patterns(
163 "/keys/changes$",
164 releases=()
165 )
158 PATTERNS = client_v2_patterns("/keys/changes$")
166159
167160 def __init__(self, hs):
168161 """
212205 } } } }
213206
214207 """
215 PATTERNS = client_v2_patterns(
216 "/keys/claim$",
217 releases=()
218 )
208 PATTERNS = client_v2_patterns("/keys/claim$")
219209
220210 def __init__(self, hs):
221211 super(OneTimeKeyServlet, self).__init__()
2929
3030
3131 class NotificationsServlet(RestServlet):
32 PATTERNS = client_v2_patterns("/notifications$", releases=())
32 PATTERNS = client_v2_patterns("/notifications$")
3333
3434 def __init__(self, hs):
3535 super(NotificationsServlet, self).__init__()
2525 RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
2626 )
2727 from synapse.util.msisdn import phone_number_to_msisdn
28
29 from ._base import client_v2_patterns
28 from synapse.util.threepids import check_3pid_allowed
29
30 from ._base import client_v2_patterns, interactive_auth_handler
3031
3132 import logging
3233 import hmac
6970 'id_server', 'client_secret', 'email', 'send_attempt'
7071 ])
7172
73 if not check_3pid_allowed(self.hs, "email", body['email']):
74 raise SynapseError(
75 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
76 )
77
7278 existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
7379 'email', body['email']
7480 )
103109 ])
104110
105111 msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
112
113 if not check_3pid_allowed(self.hs, "msisdn", msisdn):
114 raise SynapseError(
115 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
116 )
106117
107118 existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
108119 'msisdn', msisdn
171182 self.auth_handler = hs.get_auth_handler()
172183 self.registration_handler = hs.get_handlers().registration_handler
173184 self.identity_handler = hs.get_handlers().identity_handler
174 self.room_member_handler = hs.get_handlers().room_member_handler
185 self.room_member_handler = hs.get_room_member_handler()
175186 self.device_handler = hs.get_device_handler()
176187 self.macaroon_gen = hs.get_macaroon_generator()
177188
189 @interactive_auth_handler
178190 @defer.inlineCallbacks
179191 def on_POST(self, request):
180192 yield run_on_reactor()
223235 # 'user' key not 'username'). Since this is a new addition, we'll
224236 # fall back to 'username' if they gave one.
225237 desired_username = body.get("user", desired_username)
238
239 # XXX we should check that desired_username is valid. Currently
240 # we give appservices carte blanche for any insanity in mxids,
241 # because the IRC bridges rely on being able to register stupid
242 # IDs.
243
226244 access_token = get_access_token_from_request(request)
227245
228246 if isinstance(desired_username, basestring):
231249 )
232250 defer.returnValue((200, result)) # we throw for non 200 responses
233251 return
252
253 # for either shared secret or regular registration, downcase the
254 # provided username before attempting to register it. This should mean
255 # that people who try to register with upper-case in their usernames
256 # don't get a nasty surprise. (Note that we treat username
257 # case-insensitively in login, so they are free to carry on imagining
258 # that their username is CrAzYh4cKeR if that keeps them happy)
259 if desired_username is not None:
260 desired_username = desired_username.lower()
234261
235262 # == Shared Secret Registration == (e.g. create new user scripts)
236263 if 'mac' in body:
288315 if 'x_show_msisdn' in body and body['x_show_msisdn']:
289316 show_msisdn = True
290317
318 # FIXME: need a better error than "no auth flow found" for scenarios
319 # where we required 3PID for registration but the user didn't give one
320 require_email = 'email' in self.hs.config.registrations_require_3pid
321 require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
322
323 flows = []
291324 if self.hs.config.enable_registration_captcha:
292 flows = [
293 [LoginType.RECAPTCHA],
294 [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
295 ]
325 # only support 3PIDless registration if no 3PIDs are required
326 if not require_email and not require_msisdn:
327 flows.extend([[LoginType.RECAPTCHA]])
328 # only support the email-only flow if we don't require MSISDN 3PIDs
329 if not require_msisdn:
330 flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])
331
296332 if show_msisdn:
333 # only support the MSISDN-only flow if we don't require email 3PIDs
334 if not require_email:
335 flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]])
336 # always let users provide both MSISDN & email
297337 flows.extend([
298 [LoginType.MSISDN, LoginType.RECAPTCHA],
299338 [LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
300339 ])
301340 else:
302 flows = [
303 [LoginType.DUMMY],
304 [LoginType.EMAIL_IDENTITY],
305 ]
341 # only support 3PIDless registration if no 3PIDs are required
342 if not require_email and not require_msisdn:
343 flows.extend([[LoginType.DUMMY]])
344 # only support the email-only flow if we don't require MSISDN 3PIDs
345 if not require_msisdn:
346 flows.extend([[LoginType.EMAIL_IDENTITY]])
347
306348 if show_msisdn:
349 # only support the MSISDN-only flow if we don't require email 3PIDs
350 if not require_email:
351 flows.extend([[LoginType.MSISDN]])
352 # always let users provide both MSISDN & email
307353 flows.extend([
308 [LoginType.MSISDN],
309 [LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
354 [LoginType.MSISDN, LoginType.EMAIL_IDENTITY]
310355 ])
311356
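# A worked illustration (assumed configurations, not part of the diff) of the
# flow lists the logic above yields when registration captcha is disabled and
# the client sent x_show_msisdn. LoginType is already imported in this module.
#  - registrations_require_3pid = []        -> all four flows are offered
#  - registrations_require_3pid = ["email"] -> only flows with an email stage
EXAMPLE_FLOWS_NO_3PID_REQUIRED = [
    [LoginType.DUMMY],
    [LoginType.EMAIL_IDENTITY],
    [LoginType.MSISDN],
    [LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
]
EXAMPLE_FLOWS_EMAIL_REQUIRED = [
    [LoginType.EMAIL_IDENTITY],
    [LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
]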
312 authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
357 auth_result, params, session_id = yield self.auth_handler.check_auth(
313358 flows, body, self.hs.get_ip_from_request(request)
314359 )
315360
316 if not authed:
317 defer.returnValue((401, auth_result))
318 return
361 # Check that we're not trying to register a denied 3pid.
362 #
363 # the user-facing checks will probably already have happened in
364 # /register/email/requestToken when we requested a 3pid, but that's not
365 # guaranteed.
366
367 if auth_result:
368 for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
369 if login_type in auth_result:
370 medium = auth_result[login_type]['medium']
371 address = auth_result[login_type]['address']
372
373 if not check_3pid_allowed(self.hs, medium, address):
374 raise SynapseError(
375 403, "Third party identifier is not allowed",
376 Codes.THREEPID_DENIED,
377 )
319378
320379 if registered_user_id is not None:
321380 logger.info(
335394 new_password = params.get("password", None)
336395 guest_access_token = params.get("guest_access_token", None)
337396
397 if desired_username is not None:
398 desired_username = desired_username.lower()
399
338400 (registered_user_id, _) = yield self.registration_handler.register(
339401 localpart=desired_username,
340402 password=new_password,
416478 def _do_shared_secret_registration(self, username, password, body):
417479 if not self.hs.config.registration_shared_secret:
418480 raise SynapseError(400, "Shared secret registration is not enabled")
419
420 user = username.encode("utf-8")
481 if not username:
482 raise SynapseError(
483 400, "username must be specified", errcode=Codes.BAD_JSON,
484 )
485
486 # use the username from the original request rather than the
487 # downcased one in `username` for the mac calculation
488 user = body["username"].encode("utf-8")
421489
422490 # str() because otherwise hmac complains that 'unicode' does not
423491 # have the buffer interface
424492 got_mac = str(body["mac"])
425493
494 # FIXME this is different to the /v1/register endpoint, which
495 # includes the password and admin flag in the hashed text. Why are
496 # these different?
426497 want_mac = hmac.new(
427498 key=self.hs.config.registration_shared_secret,
428499 msg=user,
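# A client-side sketch of the shared-secret mac checked above. The server code
# suggests an HMAC over the raw username keyed with registration_shared_secret;
# the SHA-1 digest is an assumption, since the digestmod line is not shown in
# this hunk.
import hmac
from hashlib import sha1


def registration_mac(shared_secret, username):
    # Hash the username exactly as sent in the request body: the server
    # computes the mac over body["username"], not the downcased form.
    return hmac.new(
        key=shared_secret,
        msg=username.encode("utf-8"),
        digestmod=sha1,
    ).hexdigest()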
556627 Args:
557628 (str) user_id: full canonical @user:id
558629 (object) params: registration parameters, from which we pull
559 device_id and initial_device_name
630 device_id, initial_device_name and inhibit_login
560631 Returns:
561632 defer.Deferred: (object) dictionary for response from /register
562633 """
563 device_id = yield self._register_device(user_id, params)
564
565 access_token = (
566 yield self.auth_handler.get_access_token_for_user_id(
567 user_id, device_id=device_id,
568 initial_display_name=params.get("initial_device_display_name")
569 )
570 )
571
572 defer.returnValue({
634 result = {
573635 "user_id": user_id,
574 "access_token": access_token,
575636 "home_server": self.hs.hostname,
576 "device_id": device_id,
577 })
637 }
638 if not params.get("inhibit_login", False):
639 device_id = yield self._register_device(user_id, params)
640
641 access_token = (
642 yield self.auth_handler.get_access_token_for_user_id(
643 user_id, device_id=device_id,
644 )
645 )
646
647 result.update({
648 "access_token": access_token,
649 "device_id": device_id,
650 })
651 defer.returnValue(result)
578652
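# Illustrative response shapes (all values assumed) for the two branches
# handled above: by default the reply includes the newly-registered device and
# access token, while a request containing "inhibit_login": true gets only the
# bare identifiers back.
RESPONSE_WITH_LOGIN = {
    "user_id": "@alice:example.com",
    "home_server": "example.com",
    "access_token": "MDAxabc...",
    "device_id": "GHTYAJCE",
}
RESPONSE_INHIBIT_LOGIN = {
    "user_id": "@alice:example.com",
    "home_server": "example.com",
}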
579653 def _register_device(self, user_id, params):
580654 """Register a device for a user.
2828 class SendToDeviceRestServlet(servlet.RestServlet):
2929 PATTERNS = client_v2_patterns(
3030 "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
31 releases=[], v2_alpha=False
31 v2_alpha=False
3232 )
3333
3434 def __init__(self, hs):
3232 import itertools
3333 import logging
3434
35 import ujson as json
35 import simplejson as json
3636
3737 logger = logging.getLogger(__name__)
3838
124124 filter_object = json.loads(filter_id)
125125 set_timeline_upper_limit(filter_object,
126126 self.hs.config.filter_timeline_limit)
127 except:
127 except Exception:
128128 raise SynapseError(400, "Invalid filter JSON")
129129 self.filtering.check_valid_filter(filter_object)
130130 filter = FilterCollection(filter_object)
2525
2626
2727 class ThirdPartyProtocolsServlet(RestServlet):
28 PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=())
28 PATTERNS = client_v2_patterns("/thirdparty/protocols")
2929
3030 def __init__(self, hs):
3131 super(ThirdPartyProtocolsServlet, self).__init__()
4242
4343
4444 class ThirdPartyProtocolServlet(RestServlet):
45 PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$",
46 releases=())
45 PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
4746
4847 def __init__(self, hs):
4948 super(ThirdPartyProtocolServlet, self).__init__()
6564
6665
6766 class ThirdPartyUserServlet(RestServlet):
68 PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
69 releases=())
67 PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
7068
7169 def __init__(self, hs):
7270 super(ThirdPartyUserServlet, self).__init__()
8987
9088
9189 class ThirdPartyLocationServlet(RestServlet):
92 PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$",
93 releases=())
90 PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
9491
9592 def __init__(self, hs):
9693 super(ThirdPartyLocationServlet, self).__init__()
6464
6565 try:
6666 search_term = body["search_term"]
67 except:
67 except Exception:
6868 raise SynapseError(400, "`search_term` is required field")
6969
7070 results = yield self.user_directory_handler.search_users(
2929 "r0.0.1",
3030 "r0.1.0",
3131 "r0.2.0",
32 "r0.3.0",
3233 ]
3334 })
3435
9292 self.store = hs.get_datastore()
9393 self.version_string = hs.version_string
9494 self.clock = hs.get_clock()
95 self.federation_domain_whitelist = hs.config.federation_domain_whitelist
9596
9697 def render_GET(self, request):
9798 self.async_render_GET(request)
136137 logger.info("Handling query for keys %r", query)
137138 store_queries = []
138139 for server_name, key_ids in query.items():
140 if (
141 self.federation_domain_whitelist is not None and
142 server_name not in self.federation_domain_whitelist
143 ):
144 logger.debug("Federation denied with %s", server_name)
145 continue
146
139147 if not key_ids:
140148 key_ids = (None,)
141149 for key_id in key_ids:
212220 )
213221 except KeyLookupError as e:
214222 logger.info("Failed to fetch key: %s", e)
215 except:
223 except Exception:
216224 logger.exception("Failed to get key for %r", server_name)
217225 yield self.query_keys(
218226 request, query, query_remote_on_cache_miss=False
1616 from synapse.api.errors import (
1717 cs_error, Codes, SynapseError
1818 )
19 from synapse.util import logcontext
1920
2021 from twisted.internet import defer
2122 from twisted.protocols.basic import FileSender
4344 except UnicodeDecodeError:
4445 pass
4546 return server_name, media_id, file_name
46 except:
47 except Exception:
4748 raise SynapseError(
4849 404,
4950 "Invalid media id token %r" % (request.postpath,),
6869 logger.debug("Responding with %r", file_path)
6970
7071 if os.path.isfile(file_path):
71 request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
72 if upload_name:
73 if is_ascii(upload_name):
74 request.setHeader(
75 b"Content-Disposition",
76 b"inline; filename=%s" % (
77 urllib.quote(upload_name.encode("utf-8")),
78 ),
79 )
80 else:
81 request.setHeader(
82 b"Content-Disposition",
83 b"inline; filename*=utf-8''%s" % (
84 urllib.quote(upload_name.encode("utf-8")),
85 ),
86 )
87
88 # cache for at least a day.
89 # XXX: we might want to turn this off for data we don't want to
90 # recommend caching as it's sensitive or private - or at least
91 # select private. don't bother setting Expires as all our
92 # clients are smart enough to be happy with Cache-Control
93 request.setHeader(
94 b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
95 )
9672 if file_size is None:
9773 stat = os.stat(file_path)
9874 file_size = stat.st_size
9975
100 request.setHeader(
101 b"Content-Length", b"%d" % (file_size,)
102 )
76 add_file_headers(request, media_type, file_size, upload_name)
10377
10478 with open(file_path, "rb") as f:
105 yield FileSender().beginFileTransfer(f, request)
79 yield logcontext.make_deferred_yieldable(
80 FileSender().beginFileTransfer(f, request)
81 )
10682
10783 finish_request(request)
10884 else:
10985 respond_404(request)
86
87
88 def add_file_headers(request, media_type, file_size, upload_name):
89 """Adds the correct response headers in preparation for responding with the
90 media.
91
92 Args:
93 request (twisted.web.http.Request)
94 media_type (str): The media/content type.
95 file_size (int): Size in bytes of the media, if known.
96 upload_name (str): The name of the requested file, if any.
97 """
98 request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
99 if upload_name:
100 if is_ascii(upload_name):
101 request.setHeader(
102 b"Content-Disposition",
103 b"inline; filename=%s" % (
104 urllib.quote(upload_name.encode("utf-8")),
105 ),
106 )
107 else:
108 request.setHeader(
109 b"Content-Disposition",
110 b"inline; filename*=utf-8''%s" % (
111 urllib.quote(upload_name.encode("utf-8")),
112 ),
113 )
114
115 # cache for at least a day.
116 # XXX: we might want to turn this off for data we don't want to
117 # recommend caching as it's sensitive or private - or at least
118 # select private. don't bother setting Expires as all our
119 # clients are smart enough to be happy with Cache-Control
120 request.setHeader(
121 b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
122 )
123
124 request.setHeader(
125 b"Content-Length", b"%d" % (file_size,)
126 )
127
128
129 @defer.inlineCallbacks
130 def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
131 """Responds to the request with given responder. If responder is None then
132 returns 404.
133
134 Args:
135 request (twisted.web.http.Request)
136 responder (Responder|None)
137 media_type (str): The media/content type.
138 file_size (int|None): Size in bytes of the media. If not known it should be None
139 upload_name (str|None): The name of the requested file, if any.
140 """
141 if not responder:
142 respond_404(request)
143 return
144
145 add_file_headers(request, media_type, file_size, upload_name)
146 with responder:
147 yield responder.write_to_consumer(request)
148 finish_request(request)
149
150
151 class Responder(object):
152 """Represents a response that can be streamed to the requester.
153
154 Responder is a context manager which *must* be used, so that any resources
155 held can be cleaned up.
156 """
157 def write_to_consumer(self, consumer):
158 """Stream response into consumer
159
160 Args:
161 consumer (IConsumer)
162
163 Returns:
164 Deferred: Resolves once the response has finished being written
165 """
166 pass
167
168 def __enter__(self):
169 pass
170
171 def __exit__(self, exc_type, exc_val, exc_tb):
172 pass
173
174
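# A minimal concrete Responder, sketched here to show the intended contract:
# stream the data in write_to_consumer() and release the resource on exit.
# The name and the open-file detail are assumptions for illustration only.
class FileResponder(Responder):
    def __init__(self, open_file):
        self.open_file = open_file

    def write_to_consumer(self, consumer):
        # FileSender is imported at the top of this module
        return FileSender().beginFileTransfer(self.open_file, consumer)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.open_file.close()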
175 class FileInfo(object):
176 """Details about a requested/uploaded file.
177
178 Attributes:
179 server_name (str): The server name where the media originated from,
180 or None if local.
181 file_id (str): The local ID of the file. For local files this is the
182 same as the media_id
183 url_cache (bool): If the file is for the url preview cache
184 thumbnail (bool): Whether the file is a thumbnail or not.
185 thumbnail_width (int)
186 thumbnail_height (int)
187 thumbnail_method (str)
188 thumbnail_type (str): Content type of thumbnail, e.g. image/png
189 """
190 def __init__(self, server_name, file_id, url_cache=False,
191 thumbnail=False, thumbnail_width=None, thumbnail_height=None,
192 thumbnail_method=None, thumbnail_type=None):
193 self.server_name = server_name
194 self.file_id = file_id
195 self.url_cache = url_cache
196 self.thumbnail = thumbnail
197 self.thumbnail_width = thumbnail_width
198 self.thumbnail_height = thumbnail_height
199 self.thumbnail_method = thumbnail_method
200 self.thumbnail_type = thumbnail_type
1313 # limitations under the License.
1414 import synapse.http.servlet
1515
16 from ._base import parse_media_id, respond_with_file, respond_404
16 from ._base import parse_media_id, respond_404
1717 from twisted.web.resource import Resource
1818 from synapse.http.server import request_handler, set_cors_headers
1919
3131 def __init__(self, hs, media_repo):
3232 Resource.__init__(self)
3333
34 self.filepaths = media_repo.filepaths
3534 self.media_repo = media_repo
3635 self.server_name = hs.hostname
37 self.store = hs.get_datastore()
36
37 # Both of these are expected by @request_handler()
38 self.clock = hs.get_clock()
3839 self.version_string = hs.version_string
39 self.clock = hs.get_clock()
4040
4141 def render_GET(self, request):
4242 self._async_render_GET(request)
5656 )
5757 server_name, media_id, name = parse_media_id(request)
5858 if server_name == self.server_name:
59 yield self._respond_local_file(request, media_id, name)
59 yield self.media_repo.get_local_media(request, media_id, name)
6060 else:
61 yield self._respond_remote_file(
62 request, server_name, media_id, name
63 )
61 allow_remote = synapse.http.servlet.parse_boolean(
62 request, "allow_remote", default=True)
63 if not allow_remote:
64 logger.info(
65 "Rejecting request for remote media %s/%s due to allow_remote",
66 server_name, media_id,
67 )
68 respond_404(request)
69 return
6470
65 @defer.inlineCallbacks
66 def _respond_local_file(self, request, media_id, name):
67 media_info = yield self.store.get_local_media(media_id)
68 if not media_info or media_info["quarantined_by"]:
69 respond_404(request)
70 return
71
72 media_type = media_info["media_type"]
73 media_length = media_info["media_length"]
74 upload_name = name if name else media_info["upload_name"]
75 if media_info["url_cache"]:
76 # TODO: Check the file still exists, if it doesn't we can redownload
77 # it from the url `media_info["url_cache"]`
78 file_path = self.filepaths.url_cache_filepath(media_id)
79 else:
80 file_path = self.filepaths.local_media_filepath(media_id)
81
82 yield respond_with_file(
83 request, media_type, file_path, media_length,
84 upload_name=upload_name,
85 )
86
87 @defer.inlineCallbacks
88 def _respond_remote_file(self, request, server_name, media_id, name):
89 # don't forward requests for remote media if allow_remote is false
90 allow_remote = synapse.http.servlet.parse_boolean(
91 request, "allow_remote", default=True)
92 if not allow_remote:
93 logger.info(
94 "Rejecting request for remote media %s/%s due to allow_remote",
95 server_name, media_id,
96 )
97 respond_404(request)
98 return
99
100 media_info = yield self.media_repo.get_remote_media(server_name, media_id)
101
102 media_type = media_info["media_type"]
103 media_length = media_info["media_length"]
104 filesystem_id = media_info["filesystem_id"]
105 upload_name = name if name else media_info["upload_name"]
106
107 file_path = self.filepaths.remote_media_filepath(
108 server_name, filesystem_id
109 )
110
111 yield respond_with_file(
112 request, media_type, file_path, media_length,
113 upload_name=upload_name,
114 )
71 yield self.media_repo.get_remote_media(request, server_name, media_id, name)
00 # -*- coding: utf-8 -*-
11 # Copyright 2014-2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1718 import twisted.web.http
1819 from twisted.web.resource import Resource
1920
21 from ._base import respond_404, FileInfo, respond_with_responder
2022 from .upload_resource import UploadResource
2123 from .download_resource import DownloadResource
2224 from .thumbnail_resource import ThumbnailResource
2426 from .preview_url_resource import PreviewUrlResource
2527 from .filepath import MediaFilePaths
2628 from .thumbnailer import Thumbnailer
29 from .storage_provider import StorageProviderWrapper
30 from .media_storage import MediaStorage
2731
2832 from synapse.http.matrixfederationclient import MatrixFederationHttpClient
2933 from synapse.util.stringutils import random_string
30 from synapse.api.errors import SynapseError, HttpResponseException, \
31 NotFoundError
34 from synapse.api.errors import (
35 SynapseError, HttpResponseException, NotFoundError, FederationDeniedError,
36 )
3237
3338 from synapse.util.async import Linearizer
3439 from synapse.util.stringutils import is_ascii
35 from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
40 from synapse.util.logcontext import make_deferred_yieldable
3641 from synapse.util.retryutils import NotRetryingDestination
3742
3843 import os
4651 logger = logging.getLogger(__name__)
4752
4853
49 UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
54 UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
5055
5156
5257 class MediaRepository(object):
6267 self.primary_base_path = hs.config.media_store_path
6368 self.filepaths = MediaFilePaths(self.primary_base_path)
6469
65 self.backup_base_path = hs.config.backup_media_store_path
66
67 self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store
68
6970 self.dynamic_thumbnails = hs.config.dynamic_thumbnails
7071 self.thumbnail_requirements = hs.config.thumbnail_requirements
7172
7273 self.remote_media_linearizer = Linearizer(name="media_remote")
7374
7475 self.recently_accessed_remotes = set()
76 self.recently_accessed_locals = set()
77
78 self.federation_domain_whitelist = hs.config.federation_domain_whitelist
79
80 # List of StorageProviders where we should search for media and
81 # potentially upload to.
82 storage_providers = []
83
84 for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
85 backend = clz(hs, provider_config)
86 provider = StorageProviderWrapper(
87 backend,
88 store_local=wrapper_config.store_local,
89 store_remote=wrapper_config.store_remote,
90 store_synchronous=wrapper_config.store_synchronous,
91 )
92 storage_providers.append(provider)
93
94 self.media_storage = MediaStorage(
95 self.primary_base_path, self.filepaths, storage_providers,
96 )
7597
7698 self.clock.looping_call(
77 self._update_recently_accessed_remotes,
78 UPDATE_RECENTLY_ACCESSED_REMOTES_TS
79 )
80
81 @defer.inlineCallbacks
82 def _update_recently_accessed_remotes(self):
83 media = self.recently_accessed_remotes
99 self._update_recently_accessed,
100 UPDATE_RECENTLY_ACCESSED_TS,
101 )
102
103 @defer.inlineCallbacks
104 def _update_recently_accessed(self):
105 remote_media = self.recently_accessed_remotes
84106 self.recently_accessed_remotes = set()
85107
108 local_media = self.recently_accessed_locals
109 self.recently_accessed_locals = set()
110
86111 yield self.store.update_cached_last_access_time(
87 media, self.clock.time_msec()
88 )
89
90 @staticmethod
91 def _makedirs(filepath):
92 dirname = os.path.dirname(filepath)
93 if not os.path.exists(dirname):
94 os.makedirs(dirname)
95
96 @staticmethod
97 def _write_file_synchronously(source, fname):
98 """Write `source` to the path `fname` synchronously. Should be called
99 from a thread.
112 local_media, remote_media, self.clock.time_msec()
113 )
114
115 def mark_recently_accessed(self, server_name, media_id):
116 """Mark the given media as recently accessed.
100117
101118 Args:
102 source: A file like object to be written
103 fname (str): Path to write to
119 server_name (str|None): Origin server of media, or None if local
120 media_id (str): The media ID of the content
104121 """
105 MediaRepository._makedirs(fname)
106 source.seek(0) # Ensure we read from the start of the file
107 with open(fname, "wb") as f:
108 shutil.copyfileobj(source, f)
109
110 @defer.inlineCallbacks
111 def write_to_file_and_backup(self, source, path):
112 """Write `source` to the on disk media store, and also the backup store
113 if configured.
114
115 Args:
116 source: A file like object that should be written
117 path (str): Relative path to write file to
118
119 Returns:
120 Deferred[str]: the file path written to in the primary media store
121 """
122 fname = os.path.join(self.primary_base_path, path)
123
124 # Write to the main repository
125 yield make_deferred_yieldable(threads.deferToThread(
126 self._write_file_synchronously, source, fname,
127 ))
128
129 # Write to backup repository
130 yield self.copy_to_backup(path)
131
132 defer.returnValue(fname)
133
134 @defer.inlineCallbacks
135 def copy_to_backup(self, path):
136 """Copy a file from the primary to backup media store, if configured.
137
138 Args:
139 path(str): Relative path to write file to
140 """
141 if self.backup_base_path:
142 primary_fname = os.path.join(self.primary_base_path, path)
143 backup_fname = os.path.join(self.backup_base_path, path)
144
145 # We can either wait for successful writing to the backup repository
146 # or write in the background and immediately return
147 if self.synchronous_backup_media_store:
148 yield make_deferred_yieldable(threads.deferToThread(
149 shutil.copyfile, primary_fname, backup_fname,
150 ))
151 else:
152 preserve_fn(threads.deferToThread)(
153 shutil.copyfile, primary_fname, backup_fname,
154 )
122 if server_name:
123 self.recently_accessed_remotes.add((server_name, media_id))
124 else:
125 self.recently_accessed_locals.add(media_id)
155126
156127 @defer.inlineCallbacks
157128 def create_content(self, media_type, upload_name, content, content_length,
170141 """
171142 media_id = random_string(24)
172143
173 fname = yield self.write_to_file_and_backup(
174 content, self.filepaths.local_media_filepath_rel(media_id)
175 )
144 file_info = FileInfo(
145 server_name=None,
146 file_id=media_id,
147 )
148
149 fname = yield self.media_storage.store_file(content, file_info)
176150
177151 logger.info("Stored local media in file %r", fname)
178152
184158 media_length=content_length,
185159 user_id=auth_user,
186160 )
187 media_info = {
188 "media_type": media_type,
189 "media_length": content_length,
190 }
191
192 yield self._generate_thumbnails(None, media_id, media_info)
161
162 yield self._generate_thumbnails(
163 None, media_id, media_id, media_type,
164 )
193165
194166 defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
195167
196168 @defer.inlineCallbacks
197 def get_remote_media(self, server_name, media_id):
169 def get_local_media(self, request, media_id, name):
170 """Responds to reqests for local media, if exists, or returns 404.
171
172 Args:
173 request(twisted.web.http.Request)
174 media_id (str): The media ID of the content. (This is the same as
175 the file_id for local content.)
176 name (str|None): Optional name that, if specified, will be used as
177 the filename in the Content-Disposition header of the response.
178
179 Returns:
180 Deferred: Resolves once a response has successfully been written
181 to request
182 """
183 media_info = yield self.store.get_local_media(media_id)
184 if not media_info or media_info["quarantined_by"]:
185 respond_404(request)
186 return
187
188 self.mark_recently_accessed(None, media_id)
189
190 media_type = media_info["media_type"]
191 media_length = media_info["media_length"]
192 upload_name = name if name else media_info["upload_name"]
193 url_cache = media_info["url_cache"]
194
195 file_info = FileInfo(
196 None, media_id,
197 url_cache=url_cache,
198 )
199
200 responder = yield self.media_storage.fetch_media(file_info)
201 yield respond_with_responder(
202 request, responder, media_type, media_length, upload_name,
203 )
204
205 @defer.inlineCallbacks
206 def get_remote_media(self, request, server_name, media_id, name):
207 """Respond to requests for remote media.
208
209 Args:
210 request(twisted.web.http.Request)
211 server_name (str): Remote server_name where the media originated.
212 media_id (str): The media ID of the content (as defined by the
213 remote server).
214 name (str|None): Optional name that, if specified, will be used as
215 the filename in the Content-Disposition header of the response.
216
217 Returns:
218 Deferred: Resolves once a response has successfully been written
219 to request
220 """
221 if (
222 self.federation_domain_whitelist is not None and
223 server_name not in self.federation_domain_whitelist
224 ):
225 raise FederationDeniedError(server_name)
226
227 self.mark_recently_accessed(server_name, media_id)
228
229 # We linearize here to ensure that we don't try and download remote
230 # media multiple times concurrently
198231 key = (server_name, media_id)
199232 with (yield self.remote_media_linearizer.queue(key)):
200 media_info = yield self._get_remote_media_impl(server_name, media_id)
233 responder, media_info = yield self._get_remote_media_impl(
234 server_name, media_id,
235 )
236
237 # We deliberately stream the file outside the lock
238 if responder:
239 media_type = media_info["media_type"]
240 media_length = media_info["media_length"]
241 upload_name = name if name else media_info["upload_name"]
242 yield respond_with_responder(
243 request, responder, media_type, media_length, upload_name,
244 )
245 else:
246 respond_404(request)
247
248 @defer.inlineCallbacks
249 def get_remote_media_info(self, server_name, media_id):
250 """Gets the media info associated with the remote file, downloading
251 if necessary.
252
253 Args:
254 server_name (str): Remote server_name where the media originated.
255 media_id (str): The media ID of the content (as defined by the
256 remote server).
257
258 Returns:
259 Deferred[dict]: The media_info of the file
260 """
261 if (
262 self.federation_domain_whitelist is not None and
263 server_name not in self.federation_domain_whitelist
264 ):
265 raise FederationDeniedError(server_name)
266
267 # We linearize here to ensure that we don't try and download remote
268 # media multiple times concurrently
269 key = (server_name, media_id)
270 with (yield self.remote_media_linearizer.queue(key)):
271 responder, media_info = yield self._get_remote_media_impl(
272 server_name, media_id,
273 )
274
275 # Ensure we actually use the responder so that it releases resources
276 if responder:
277 with responder:
278 pass
279
201280 defer.returnValue(media_info)
202281
203282 @defer.inlineCallbacks
204283 def _get_remote_media_impl(self, server_name, media_id):
284 """Looks for media in local cache, if not there then attempt to
285 download from remote server.
286
287 Args:
288 server_name (str): Remote server_name where the media originated.
289 media_id (str): The media ID of the content (as defined by the
290 remote server).
291
292 Returns:
293 Deferred[(Responder, media_info)]
294 """
205295 media_info = yield self.store.get_cached_remote_media(
206296 server_name, media_id
207297 )
208 if not media_info:
209 media_info = yield self._download_remote_file(
210 server_name, media_id
211 )
212 elif media_info["quarantined_by"]:
213 raise NotFoundError()
298
299 # file_id is the ID we use to track the file locally. If we've already
300 # seen the file then reuse the existing ID, otherwise generate a new
301 # one.
302 if media_info:
303 file_id = media_info["filesystem_id"]
214304 else:
215 self.recently_accessed_remotes.add((server_name, media_id))
216 yield self.store.update_cached_last_access_time(
217 [(server_name, media_id)], self.clock.time_msec()
218 )
219 defer.returnValue(media_info)
220
221 @defer.inlineCallbacks
222 def _download_remote_file(self, server_name, media_id):
223 file_id = random_string(24)
224
225 fpath = self.filepaths.remote_media_filepath_rel(
226 server_name, file_id
227 )
228 fname = os.path.join(self.primary_base_path, fpath)
229 self._makedirs(fname)
230
231 try:
232 with open(fname, "wb") as f:
233 request_path = "/".join((
234 "/_matrix/media/v1/download", server_name, media_id,
235 ))
305 file_id = random_string(24)
306
307 file_info = FileInfo(server_name, file_id)
308
309 # If we have an entry in the DB, try and look for it
310 if media_info:
311 if media_info["quarantined_by"]:
312 logger.info("Media is quarantined")
313 raise NotFoundError()
314
315 responder = yield self.media_storage.fetch_media(file_info)
316 if responder:
317 defer.returnValue((responder, media_info))
318
319 # Failed to find the file anywhere, let's download it.
320
321 media_info = yield self._download_remote_file(
322 server_name, media_id, file_id
323 )
324
325 responder = yield self.media_storage.fetch_media(file_info)
326 defer.returnValue((responder, media_info))
327
328 @defer.inlineCallbacks
329 def _download_remote_file(self, server_name, media_id, file_id):
330 """Attempt to download the remote file from the given server name,
331 using the given file_id as the local id.
332
333 Args:
334 server_name (str): Originating server
335 media_id (str): The media ID of the content (as defined by the
336 remote server). This is different than the file_id, which is
337 locally generated.
338 file_id (str): Local file ID
339
340 Returns:
341 Deferred[MediaInfo]
342 """
343
344 file_info = FileInfo(
345 server_name=server_name,
346 file_id=file_id,
347 )
348
349 with self.media_storage.store_into_file(file_info) as (f, fname, finish):
350 request_path = "/".join((
351 "/_matrix/media/v1/download", server_name, media_id,
352 ))
353 try:
354 length, headers = yield self.client.get_file(
355 server_name, request_path, output_stream=f,
356 max_size=self.max_upload_size, args={
357 # tell the remote server to 404 if it doesn't
358 # recognise the server_name, to make sure we don't
359 # end up with a routing loop.
360 "allow_remote": "false",
361 }
362 )
363 except twisted.internet.error.DNSLookupError as e:
364 logger.warn("HTTP error fetching remote media %s/%s: %r",
365 server_name, media_id, e)
366 raise NotFoundError()
367
368 except HttpResponseException as e:
369 logger.warn("HTTP error fetching remote media %s/%s: %s",
370 server_name, media_id, e.response)
371 if e.code == twisted.web.http.NOT_FOUND:
372 raise SynapseError.from_http_response_exception(e)
373 raise SynapseError(502, "Failed to fetch remote media")
374
375 except SynapseError:
376 logger.exception("Failed to fetch remote media %s/%s",
377 server_name, media_id)
378 raise
379 except NotRetryingDestination:
380 logger.warn("Not retrying destination %r", server_name)
381 raise SynapseError(502, "Failed to fetch remote media")
382 except Exception:
383 logger.exception("Failed to fetch remote media %s/%s",
384 server_name, media_id)
385 raise SynapseError(502, "Failed to fetch remote media")
386
387 yield finish()
388
389 media_type = headers["Content-Type"][0]
390
391 time_now_ms = self.clock.time_msec()
392
393 content_disposition = headers.get("Content-Disposition", None)
394 if content_disposition:
395 _, params = cgi.parse_header(content_disposition[0],)
396 upload_name = None
397
398 # First check if there is a valid UTF-8 filename
399 upload_name_utf8 = params.get("filename*", None)
400 if upload_name_utf8:
401 if upload_name_utf8.lower().startswith("utf-8''"):
402 upload_name = upload_name_utf8[7:]
403
404 # If there isn't check for an ascii name.
405 if not upload_name:
406 upload_name_ascii = params.get("filename", None)
407 if upload_name_ascii and is_ascii(upload_name_ascii):
408 upload_name = upload_name_ascii
409
410 if upload_name:
411 upload_name = urlparse.unquote(upload_name)
236412 try:
237 length, headers = yield self.client.get_file(
238 server_name, request_path, output_stream=f,
239 max_size=self.max_upload_size, args={
240 # tell the remote server to 404 if it doesn't
241 # recognise the server_name, to make sure we don't
242 # end up with a routing loop.
243 "allow_remote": "false",
244 }
245 )
246 except twisted.internet.error.DNSLookupError as e:
247 logger.warn("HTTP error fetching remote media %s/%s: %r",
248 server_name, media_id, e)
249 raise NotFoundError()
250
251 except HttpResponseException as e:
252 logger.warn("HTTP error fetching remote media %s/%s: %s",
253 server_name, media_id, e.response)
254 if e.code == twisted.web.http.NOT_FOUND:
255 raise SynapseError.from_http_response_exception(e)
256 raise SynapseError(502, "Failed to fetch remote media")
257
258 except SynapseError:
259 logger.exception("Failed to fetch remote media %s/%s",
260 server_name, media_id)
261 raise
262 except NotRetryingDestination:
263 logger.warn("Not retrying destination %r", server_name)
264 raise SynapseError(502, "Failed to fetch remote media")
265 except Exception:
266 logger.exception("Failed to fetch remote media %s/%s",
267 server_name, media_id)
268 raise SynapseError(502, "Failed to fetch remote media")
269
270 yield self.copy_to_backup(fpath)
271
272 media_type = headers["Content-Type"][0]
273 time_now_ms = self.clock.time_msec()
274
275 content_disposition = headers.get("Content-Disposition", None)
276 if content_disposition:
277 _, params = cgi.parse_header(content_disposition[0],)
278 upload_name = None
279
280 # First check if there is a valid UTF-8 filename
281 upload_name_utf8 = params.get("filename*", None)
282 if upload_name_utf8:
283 if upload_name_utf8.lower().startswith("utf-8''"):
284 upload_name = upload_name_utf8[7:]
285
286 # If there isn't check for an ascii name.
287 if not upload_name:
288 upload_name_ascii = params.get("filename", None)
289 if upload_name_ascii and is_ascii(upload_name_ascii):
290 upload_name = upload_name_ascii
291
292 if upload_name:
293 upload_name = urlparse.unquote(upload_name)
294 try:
295 upload_name = upload_name.decode("utf-8")
296 except UnicodeDecodeError:
297 upload_name = None
298 else:
299 upload_name = None
300
301 logger.info("Stored remote media in file %r", fname)
302
303 yield self.store.store_cached_remote_media(
304 origin=server_name,
305 media_id=media_id,
306 media_type=media_type,
307 time_now_ms=self.clock.time_msec(),
308 upload_name=upload_name,
309 media_length=length,
310 filesystem_id=file_id,
311 )
312 except:
313 os.remove(fname)
314 raise
413 upload_name = upload_name.decode("utf-8")
414 except UnicodeDecodeError:
415 upload_name = None
416 else:
417 upload_name = None
418
419 logger.info("Stored remote media in file %r", fname)
420
421 yield self.store.store_cached_remote_media(
422 origin=server_name,
423 media_id=media_id,
424 media_type=media_type,
425 time_now_ms=self.clock.time_msec(),
426 upload_name=upload_name,
427 media_length=length,
428 filesystem_id=file_id,
429 )
315430
316431 media_info = {
317432 "media_type": media_type,
322437 }
323438
324439 yield self._generate_thumbnails(
325 server_name, media_id, media_info
440 server_name, media_id, file_id, media_type,
326441 )
327442
328443 defer.returnValue(media_info)
356471
357472 @defer.inlineCallbacks
358473 def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
359 t_method, t_type):
360 input_path = self.filepaths.local_media_filepath(media_id)
474 t_method, t_type, url_cache):
475 input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
476 None, media_id, url_cache=url_cache,
477 ))
361478
362479 thumbnailer = Thumbnailer(input_path)
363480 t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
367484
368485 if t_byte_source:
369486 try:
370 output_path = yield self.write_to_file_and_backup(
371 t_byte_source,
372 self.filepaths.local_media_thumbnail_rel(
373 media_id, t_width, t_height, t_type, t_method
374 )
487 file_info = FileInfo(
488 server_name=None,
489 file_id=media_id,
490 url_cache=url_cache,
491 thumbnail=True,
492 thumbnail_width=t_width,
493 thumbnail_height=t_height,
494 thumbnail_method=t_method,
495 thumbnail_type=t_type,
496 )
497
498 output_path = yield self.media_storage.store_file(
499 t_byte_source, file_info,
375500 )
376501 finally:
377502 t_byte_source.close()
389514 @defer.inlineCallbacks
390515 def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
391516 t_width, t_height, t_method, t_type):
392 input_path = self.filepaths.remote_media_filepath(server_name, file_id)
517 input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
518 server_name, file_id, url_cache=False,
519 ))
393520
394521 thumbnailer = Thumbnailer(input_path)
395522 t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
399526
400527 if t_byte_source:
401528 try:
402 output_path = yield self.write_to_file_and_backup(
403 t_byte_source,
404 self.filepaths.remote_media_thumbnail_rel(
405 server_name, file_id, t_width, t_height, t_type, t_method
406 )
529 file_info = FileInfo(
530 server_name=server_name,
531 file_id=media_id,
532 thumbnail=True,
533 thumbnail_width=t_width,
534 thumbnail_height=t_height,
535 thumbnail_method=t_method,
536 thumbnail_type=t_type,
537 )
538
539 output_path = yield self.media_storage.store_file(
540 t_byte_source, file_info,
407541 )
408542 finally:
409543 t_byte_source.close()
420554 defer.returnValue(output_path)
421555
422556 @defer.inlineCallbacks
423 def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False):
557 def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
558 url_cache=False):
424559 """Generate and store thumbnails for an image.
425560
426561 Args:
427 server_name(str|None): The server name if remote media, else None if local
428 media_id(str)
429 media_info(dict)
430 url_cache(bool): If we are thumbnailing images downloaded for the URL cache,
562 server_name (str|None): The server name if remote media, else None if local
563 media_id (str): The media ID of the content. (This is the same as
564 the file_id for local content)
565 file_id (str): Local file ID
566 media_type (str): The content type of the file
567 url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
431568 used exclusively by the url previewer
432569
433570 Returns:
434571 Deferred[dict]: Dict with "width" and "height" keys of original image
435572 """
436 media_type = media_info["media_type"]
437 file_id = media_info.get("filesystem_id")
438573 requirements = self._get_thumbnail_requirements(media_type)
439574 if not requirements:
440575 return
441576
442 if server_name:
443 input_path = self.filepaths.remote_media_filepath(server_name, file_id)
444 elif url_cache:
445 input_path = self.filepaths.url_cache_filepath(media_id)
446 else:
447 input_path = self.filepaths.local_media_filepath(media_id)
577 input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
578 server_name, file_id, url_cache=url_cache,
579 ))
448580
449581 thumbnailer = Thumbnailer(input_path)
450582 m_width = thumbnailer.width
471603
472604 # Now we generate the thumbnails for each dimension, store it
473605 for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
474 # Work out the correct file name for thumbnail
475 if server_name:
476 file_path = self.filepaths.remote_media_thumbnail_rel(
477 server_name, file_id, t_width, t_height, t_type, t_method
478 )
479 elif url_cache:
480 file_path = self.filepaths.url_cache_thumbnail_rel(
481 media_id, t_width, t_height, t_type, t_method
482 )
483 else:
484 file_path = self.filepaths.local_media_thumbnail_rel(
485 media_id, t_width, t_height, t_type, t_method
486 )
487
488606 # Generate the thumbnail
489607 if t_method == "crop":
490608 t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
504622 continue
505623
506624 try:
507 # Write to disk
508 output_path = yield self.write_to_file_and_backup(
509 t_byte_source, file_path,
625 file_info = FileInfo(
626 server_name=server_name,
627 file_id=file_id,
628 thumbnail=True,
629 thumbnail_width=t_width,
630 thumbnail_height=t_height,
631 thumbnail_method=t_method,
632 thumbnail_type=t_type,
633 url_cache=url_cache,
634 )
635
636 output_path = yield self.media_storage.store_file(
637 t_byte_source, file_info,
510638 )
511639 finally:
512640 t_byte_source.close()
619747
620748 self.putChild("upload", UploadResource(hs, media_repo))
621749 self.putChild("download", DownloadResource(hs, media_repo))
622 self.putChild("thumbnail", ThumbnailResource(hs, media_repo))
750 self.putChild("thumbnail", ThumbnailResource(
751 hs, media_repo, media_repo.media_storage,
752 ))
623753 self.putChild("identicon", IdenticonResource())
624754 if hs.config.url_preview_enabled:
625 self.putChild("preview_url", PreviewUrlResource(hs, media_repo))
755 self.putChild("preview_url", PreviewUrlResource(
756 hs, media_repo, media_repo.media_storage,
757 ))
0 # -*- coding: utf-8 -*-
1 # Copyright 2018 New Vector Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from twisted.internet import defer, threads
16 from twisted.protocols.basic import FileSender
17
18 from ._base import Responder
19
20 from synapse.util.file_consumer import BackgroundFileConsumer
21 from synapse.util.logcontext import make_deferred_yieldable
22
23 import contextlib
24 import os
25 import logging
26 import shutil
27 import sys
28
29
30 logger = logging.getLogger(__name__)
31
32
33 class MediaStorage(object):
34 """Responsible for storing/fetching files from local sources.
35
36 Args:
37 local_media_directory (str): Base path where we store media on disk
38 filepaths (MediaFilePaths)
39 storage_providers ([StorageProvider]): List of StorageProvider that are
40 used to fetch and store files.
41 """
42
43 def __init__(self, local_media_directory, filepaths, storage_providers):
44 self.local_media_directory = local_media_directory
45 self.filepaths = filepaths
46 self.storage_providers = storage_providers
47
48 @defer.inlineCallbacks
49 def store_file(self, source, file_info):
50 """Write `source` to the on disk media store, and also any other
51 configured storage providers
52
53 Args:
54 source: A file like object that should be written
55 file_info (FileInfo): Info about the file to store
56
57 Returns:
58 Deferred[str]: the file path written to in the primary media store
59 """
60
61 with self.store_into_file(file_info) as (f, fname, finish_cb):
62 # Write to the main repository
63 yield make_deferred_yieldable(threads.deferToThread(
64 _write_file_synchronously, source, f,
65 ))
66 yield finish_cb()
67
68 defer.returnValue(fname)
69
70 @contextlib.contextmanager
71 def store_into_file(self, file_info):
72 """Context manager used to get a file like object to write into, as
73 described by file_info.
74
75 Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
76 like object that can be written to, fname is the absolute path of file
77 on disk, and finish_cb is a function that returns a Deferred.
78
79 fname can be used to read the contents from after upload, e.g. to
80 generate thumbnails.
81
82 finish_cb must be called and waited on after the file has been
83 successfully written to. Should not be called if there was an
84 error.
85
86 Args:
87 file_info (FileInfo): Info about the file to store
88
89 Example:
90
91 with media_storage.store_into_file(info) as (f, fname, finish_cb):
92 # .. write into f ...
93 yield finish_cb()
94 """
95
96 path = self._file_info_to_path(file_info)
97 fname = os.path.join(self.local_media_directory, path)
98
99 dirname = os.path.dirname(fname)
100 if not os.path.exists(dirname):
101 os.makedirs(dirname)
102
103 finished_called = [False]
104
105 @defer.inlineCallbacks
106 def finish():
107 for provider in self.storage_providers:
108 yield provider.store_file(path, file_info)
109
110 finished_called[0] = True
111
112 try:
113 with open(fname, "wb") as f:
114 yield f, fname, finish
115 except Exception:
116 t, v, tb = sys.exc_info()
117 try:
118 os.remove(fname)
119 except Exception:
120 pass
121 raise t, v, tb
122
123 if not finished_called[0]:
124 raise Exception("Finished callback not called")
125
126 @defer.inlineCallbacks
127 def fetch_media(self, file_info):
128 """Attempts to fetch media described by file_info from the local cache
129 and configured storage providers.
130
131 Args:
132 file_info (FileInfo)
133
134 Returns:
135 Deferred[Responder|None]: Returns a Responder if the file was found,
136 otherwise None.
137 """
138
139 path = self._file_info_to_path(file_info)
140 local_path = os.path.join(self.local_media_directory, path)
141 if os.path.exists(local_path):
142 defer.returnValue(FileResponder(open(local_path, "rb")))
143
144 for provider in self.storage_providers:
145 res = yield provider.fetch(path, file_info)
146 if res:
147 defer.returnValue(res)
148
149 defer.returnValue(None)
150
151 @defer.inlineCallbacks
152 def ensure_media_is_in_local_cache(self, file_info):
153 """Ensures that the given file is in the local cache. Attempts to
154 download it from storage providers if it isn't.
155
156 Args:
157 file_info (FileInfo)
158
159 Returns:
160 Deferred[str]: Full path to local file
161 """
162 path = self._file_info_to_path(file_info)
163 local_path = os.path.join(self.local_media_directory, path)
164 if os.path.exists(local_path):
165 defer.returnValue(local_path)
166
167 dirname = os.path.dirname(local_path)
168 if not os.path.exists(dirname):
169 os.makedirs(dirname)
170
171 for provider in self.storage_providers:
172 res = yield provider.fetch(path, file_info)
173 if res:
174 with res:
175 consumer = BackgroundFileConsumer(open(local_path, "wb"))
176 yield res.write_to_consumer(consumer)
177 yield consumer.wait()
178 defer.returnValue(local_path)
179
180 raise Exception("file could not be found")
181
182 def _file_info_to_path(self, file_info):
183 """Converts file_info into a relative path.
184
185 The path is suitable for storing files under a directory, e.g. used to
186 store files on local FS under the base media repository directory.
187
188 Args:
189 file_info (FileInfo)
190
191 Returns:
192 str
193 """
194 if file_info.url_cache:
195 if file_info.thumbnail:
196 return self.filepaths.url_cache_thumbnail_rel(
197 media_id=file_info.file_id,
198 width=file_info.thumbnail_width,
199 height=file_info.thumbnail_height,
200 content_type=file_info.thumbnail_type,
201 method=file_info.thumbnail_method,
202 )
203 return self.filepaths.url_cache_filepath_rel(file_info.file_id)
204
205 if file_info.server_name:
206 if file_info.thumbnail:
207 return self.filepaths.remote_media_thumbnail_rel(
208 server_name=file_info.server_name,
209 file_id=file_info.file_id,
210 width=file_info.thumbnail_width,
211 height=file_info.thumbnail_height,
212 content_type=file_info.thumbnail_type,
213 method=file_info.thumbnail_method
214 )
215 return self.filepaths.remote_media_filepath_rel(
216 file_info.server_name, file_info.file_id,
217 )
218
219 if file_info.thumbnail:
220 return self.filepaths.local_media_thumbnail_rel(
221 media_id=file_info.file_id,
222 width=file_info.thumbnail_width,
223 height=file_info.thumbnail_height,
224 content_type=file_info.thumbnail_type,
225 method=file_info.thumbnail_method
226 )
227 return self.filepaths.local_media_filepath_rel(
228 file_info.file_id,
229 )
230
231
232 def _write_file_synchronously(source, dest):
233 """Write `source` to the file like `dest` synchronously. Should be called
234 from a thread.
235
236 Args:
237 source: A file like object that's to be written
238 dest: A file like object to be written to
239 """
240 source.seek(0) # Ensure we read from the start of the file
241 shutil.copyfileobj(source, dest)
242
243
244 class FileResponder(Responder):
245 """Wraps an open file that can be sent to a request.
246
247 Args:
248 open_file (file): A file like object to be streamed to the client,
249 is closed when finished streaming.
250 """
251 def __init__(self, open_file):
252 self.open_file = open_file
253
254 def write_to_consumer(self, consumer):
255 return FileSender().beginFileTransfer(self.open_file, consumer)
256
257 def __exit__(self, exc_type, exc_val, exc_tb):
258 self.open_file.close()
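
To make the new MediaStorage API above easier to follow, here is a minimal sketch of how a caller might combine ``store_file`` and ``fetch_media``. It is illustrative only and not part of the diff: the helper name ``store_then_serve`` and the ``consumer`` argument are hypothetical, and ``FileInfo`` is assumed to be the one exported by ``synapse.rest.media.v1._base``, as imported elsewhere in this changeset.

# Illustrative sketch only; not part of media_storage.py.
from io import BytesIO

from twisted.internet import defer

from synapse.rest.media.v1._base import FileInfo


@defer.inlineCallbacks
def store_then_serve(media_storage, media_id, upload_bytes, consumer):
    # server_name=None marks this as local (non-remote) media
    file_info = FileInfo(server_name=None, file_id=media_id)

    # write to the local media store and any configured storage providers
    fname = yield media_storage.store_file(BytesIO(upload_bytes), file_info)

    # later: look the file up again (local cache first, then providers)
    # and stream it to an IConsumer such as a twisted Request
    responder = yield media_storage.fetch_media(file_info)
    if responder is None:
        raise Exception("media not found in any configured store")
    with responder:
        yield responder.write_to_consumer(consumer)

    defer.returnValue(fname)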
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 import cgi
15 import datetime
16 import errno
17 import fnmatch
18 import itertools
19 import logging
20 import os
21 import re
22 import shutil
23 import sys
24 import traceback
25 import simplejson as json
26 import urlparse
1427
1528 from twisted.web.server import NOT_DONE_YET
1629 from twisted.internet import defer
1730 from twisted.web.resource import Resource
1831
32 from ._base import FileInfo
33
1934 from synapse.api.errors import (
2035 SynapseError, Codes,
2136 )
37 from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
2238 from synapse.util.stringutils import random_string
2339 from synapse.util.caches.expiringcache import ExpiringCache
2440 from synapse.http.client import SpiderHttpClient
2541 from synapse.http.server import (
26 request_handler, respond_with_json_bytes
42 request_handler, respond_with_json_bytes,
43 respond_with_json,
2744 )
2845 from synapse.util.async import ObservableDeferred
2946 from synapse.util.stringutils import is_ascii
3047
31 import os
32 import re
33 import fnmatch
34 import cgi
35 import ujson as json
36 import urlparse
37 import itertools
38 import datetime
39 import errno
40 import shutil
41
42 import logging
4348 logger = logging.getLogger(__name__)
4449
4550
4651 class PreviewUrlResource(Resource):
4752 isLeaf = True
4853
49 def __init__(self, hs, media_repo):
54 def __init__(self, hs, media_repo, media_storage):
5055 Resource.__init__(self)
5156
5257 self.auth = hs.get_auth()
5964 self.client = SpiderHttpClient(hs)
6065 self.media_repo = media_repo
6166 self.primary_base_path = media_repo.primary_base_path
67 self.media_storage = media_storage
6268
6369 self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
6470
65 # simple memory cache mapping urls to OG metadata
66 self.cache = ExpiringCache(
71 # memory cache mapping urls to an ObservableDeferred returning
72 # JSON-encoded OG metadata
73 self._cache = ExpiringCache(
6774 cache_name="url_previews",
6875 clock=self.clock,
6976 # don't spider URLs more often than once an hour
7077 expiry_ms=60 * 60 * 1000,
7178 )
72 self.cache.start()
73
74 self.downloads = {}
79 self._cache.start()
7580
7681 self._cleaner_loop = self.clock.looping_call(
7782 self._expire_url_cache_data, 10 * 1000
7883 )
84
85 def render_OPTIONS(self, request):
86 return respond_with_json(request, 200, {}, send_cors=True)
7987
8088 def render_GET(self, request):
8189 self._async_render_GET(request)
93101 else:
94102 ts = self.clock.time_msec()
95103
104 # XXX: we could move this into _do_preview if we wanted.
96105 url_tuple = urlparse.urlsplit(url)
97106 for entry in self.url_preview_url_blacklist:
98107 match = True
125134 Codes.UNKNOWN
126135 )
127136
128 # first check the memory cache - good to handle all the clients on this
129 # HS thundering away to preview the same URL at the same time.
130 og = self.cache.get(url)
131 if og:
132 respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
133 return
134
135 # then check the URL cache in the DB (which will also provide us with
137 # the in-memory cache:
138 # * ensures that only one request is active at a time
139 # * takes load off the DB for the thundering herds
140 # * also caches any failures (unlike the DB) so we don't keep
141 # requesting the same endpoint
142
143 observable = self._cache.get(url)
144
145 if not observable:
146 download = preserve_fn(self._do_preview)(
147 url, requester.user, ts,
148 )
149 observable = ObservableDeferred(
150 download,
151 consumeErrors=True
152 )
153 self._cache[url] = observable
154 else:
155 logger.info("Returning cached response")
156
157 og = yield make_deferred_yieldable(observable.observe())
158 respond_with_json_bytes(request, 200, og, send_cors=True)
159
160 @defer.inlineCallbacks
161 def _do_preview(self, url, user, ts):
162 """Check the db, and download the URL and build a preview
163
164 Args:
165 url (str):
166 user (str):
167 ts (int):
168
169 Returns:
170 Deferred[str]: json-encoded og data
171 """
172 # check the URL cache in the DB (which will also provide us with
136173 # historical previews, if we have any)
137174 cache_result = yield self.store.get_url_cache(url, ts)
138175 if (
140177 cache_result["expires_ts"] > ts and
141178 cache_result["response_code"] / 100 == 2
142179 ):
143 respond_with_json_bytes(
144 request, 200, cache_result["og"].encode('utf-8'),
145 send_cors=True
146 )
180 defer.returnValue(cache_result["og"])
147181 return
148182
149 # Ensure only one download for a given URL is active at a time
150 download = self.downloads.get(url)
151 if download is None:
152 download = self._download_url(url, requester.user)
153 download = ObservableDeferred(
154 download,
155 consumeErrors=True
156 )
157 self.downloads[url] = download
158
159 @download.addBoth
160 def callback(media_info):
161 del self.downloads[url]
162 return media_info
163 media_info = yield download.observe()
164
165 # FIXME: we should probably update our cache now anyway, so that
166 # even if the OG calculation raises, we don't keep hammering on the
167 # remote server. For now, leave it uncached to aid debugging OG
168 # calculation problems
183 media_info = yield self._download_url(url, user)
169184
170185 logger.debug("got media_info of '%s'" % media_info)
171186
172187 if _is_media(media_info['media_type']):
188 file_id = media_info['filesystem_id']
173189 dims = yield self.media_repo._generate_thumbnails(
174 None, media_info['filesystem_id'], media_info, url_cache=True,
190 None, file_id, file_id, media_info["media_type"],
191 url_cache=True,
175192 )
176193
177194 og = {
211228 # just rely on the caching on the master request to speed things up.
212229 if 'og:image' in og and og['og:image']:
213230 image_info = yield self._download_url(
214 _rebase_url(og['og:image'], media_info['uri']), requester.user
231 _rebase_url(og['og:image'], media_info['uri']), user
215232 )
216233
217234 if _is_media(image_info['media_type']):
218235 # TODO: make sure we don't choke on white-on-transparent images
236 file_id = image_info['filesystem_id']
219237 dims = yield self.media_repo._generate_thumbnails(
220 None, image_info['filesystem_id'], image_info, url_cache=True,
238 None, file_id, file_id, image_info["media_type"],
239 url_cache=True,
221240 )
222241 if dims:
223242 og["og:image:width"] = dims['width']
238257
239258 logger.debug("Calculated OG for %s as %s" % (url, og))
240259
241 # store OG in ephemeral in-memory cache
242 self.cache[url] = og
260 jsonog = json.dumps(og)
243261
244262 # store OG in history-aware DB cache
245263 yield self.store.store_url_cache(
247265 media_info["response_code"],
248266 media_info["etag"],
249267 media_info["expires"] + media_info["created_ts"],
250 json.dumps(og),
268 jsonog,
251269 media_info["filesystem_id"],
252270 media_info["created_ts"],
253271 )
254272
255 respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
273 defer.returnValue(jsonog)
256274
257275 @defer.inlineCallbacks
258276 def _download_url(self, url, user):
262280
263281 file_id = datetime.date.today().isoformat() + '_' + random_string(16)
264282
265 fpath = self.filepaths.url_cache_filepath_rel(file_id)
266 fname = os.path.join(self.primary_base_path, fpath)
267 self.media_repo._makedirs(fname)
268
269 try:
270 with open(fname, "wb") as f:
283 file_info = FileInfo(
284 server_name=None,
285 file_id=file_id,
286 url_cache=True,
287 )
288
289 with self.media_storage.store_into_file(file_info) as (f, fname, finish):
290 try:
271291 logger.debug("Trying to get url '%s'" % url)
272292 length, headers, uri, code = yield self.client.get_file(
273293 url, output_stream=f, max_size=self.max_spider_size,
274294 )
295 except Exception as e:
275296 # FIXME: pass through 404s and other error messages nicely
276
277 yield self.media_repo.copy_to_backup(fpath)
278
279 media_type = headers["Content-Type"][0]
297 logger.warn("Error downloading %s: %r", url, e)
298 raise SynapseError(
299 500, "Failed to download content: %s" % (
300 traceback.format_exception_only(type(e), e),
301 ),
302 Codes.UNKNOWN,
303 )
304 yield finish()
305
306 try:
307 if "Content-Type" in headers:
308 media_type = headers["Content-Type"][0]
309 else:
310 media_type = "application/octet-stream"
280311 time_now_ms = self.clock.time_msec()
281312
282313 content_disposition = headers.get("Content-Disposition", None)
316347 )
317348
318349 except Exception as e:
319 os.remove(fname)
320 raise SynapseError(
321 500, ("Failed to download content: %s" % e),
322 Codes.UNKNOWN
323 )
350 logger.error("Error handling downloaded %s: %r", url, e)
351 # TODO: we really ought to delete the downloaded file in this
352 # case, since we won't have recorded it in the db, and will
353 # therefore not expire it.
354 raise
324355
325356 defer.returnValue({
326357 "media_type": media_type,
341372 def _expire_url_cache_data(self):
342373 """Clean up expired url cache content, media and thumbnails.
343374 """
344
345375 # TODO: Delete from backup media store
346376
347377 now = self.clock.time_msec()
378
379 logger.info("Running url preview cache expiry")
380
381 if not (yield self.store.has_completed_background_updates()):
382 logger.info("Still running DB updates; skipping expiry")
383 return
348384
349385 # First we delete expired url cache entries
350386 media_ids = yield self.store.get_expired_url_cache(now)
366402 dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
367403 for dir in dirs:
368404 os.rmdir(dir)
369 except:
405 except Exception:
370406 pass
371407
372408 yield self.store.delete_url_cache(removed_media)
396432 dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
397433 for dir in dirs:
398434 os.rmdir(dir)
399 except:
435 except Exception:
400436 pass
401437
402438 thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id)
414450 dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id)
415451 for dir in dirs:
416452 os.rmdir(dir)
417 except:
453 except Exception:
418454 pass
419455
420456 yield self.store.delete_url_cache_media(removed_media)
421457
422 if removed_media:
423 logger.info("Deleted %d media from url cache", len(removed_media))
458 logger.info("Deleted %d media from url cache", len(removed_media))
424459
425460
426461 def decode_and_calc_og(body, media_uri, request_encoding=None):
519554 from lxml import etree
520555
521556 TAGS_TO_REMOVE = (
522 "header", "nav", "aside", "footer", "script", "style", etree.Comment
557 "header",
558 "nav",
559 "aside",
560 "footer",
561 "script",
562 "noscript",
563 "style",
564 etree.Comment
523565 )
524566
525567 # Split all the text nodes into paragraphs (by splitting on new
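
The reworked ``_async_render_GET`` above deduplicates concurrent previews of the same URL by caching an ObservableDeferred per URL. Below is a minimal standalone sketch of that pattern with hypothetical names; the real code keeps the observables in an ExpiringCache so that failures also expire.

# Minimal sketch of the per-URL deduplication used by _async_render_GET above.
# The module-level dict stands in for the ExpiringCache in the real code.
from twisted.internet import defer

from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn

_inflight = {}  # url -> ObservableDeferred of the preview result


@defer.inlineCallbacks
def get_preview(url, do_preview):
    """do_preview(url) returns a Deferred; only one call per URL runs at once."""
    observable = _inflight.get(url)
    if not observable:
        # start the expensive work and let later callers observe the same result
        observable = ObservableDeferred(
            preserve_fn(do_preview)(url), consumeErrors=True,
        )
        _inflight[url] = observable
    result = yield make_deferred_yieldable(observable.observe())
    defer.returnValue(result)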
0 # -*- coding: utf-8 -*-
1 # Copyright 2018 New Vector Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from twisted.internet import defer, threads
16
17 from .media_storage import FileResponder
18
19 from synapse.config._base import Config
20 from synapse.util.logcontext import preserve_fn
21
22 import logging
23 import os
24 import shutil
25
26
27 logger = logging.getLogger(__name__)
28
29
30 class StorageProvider(object):
31 """A storage provider is a service that can store uploaded media and
32 retrieve them.
33 """
34 def store_file(self, path, file_info):
35 """Store the file described by file_info. The actual contents can be
36 retrieved by reading the file in file_info.upload_path.
37
38 Args:
39 path (str): Relative path of file in local cache
40 file_info (FileInfo)
41
42 Returns:
43 Deferred
44 """
45 pass
46
47 def fetch(self, path, file_info):
48 """Attempt to fetch the file described by file_info and stream it
49 back to the caller via a Responder.
50
51 Args:
52 path (str): Relative path of file in local cache
53 file_info (FileInfo)
54
55 Returns:
56 Deferred(Responder): Returns a Responder if the provider has the file,
57 otherwise returns None.
58 """
59 pass
60
61
62 class StorageProviderWrapper(StorageProvider):
63 """Wraps a storage provider and provides various config options
64
65 Args:
66 backend (StorageProvider)
67 store_local (bool): Whether to store new local files or not.
68 store_synchronous (bool): Whether to wait for file to be successfully
69 uploaded, or to do the upload in the background.
70 store_remote (bool): Whether remote media should be uploaded
71 """
72 def __init__(self, backend, store_local, store_synchronous, store_remote):
73 self.backend = backend
74 self.store_local = store_local
75 self.store_synchronous = store_synchronous
76 self.store_remote = store_remote
77
78 def store_file(self, path, file_info):
79 if not file_info.server_name and not self.store_local:
80 return defer.succeed(None)
81
82 if file_info.server_name and not self.store_remote:
83 return defer.succeed(None)
84
85 if self.store_synchronous:
86 return self.backend.store_file(path, file_info)
87 else:
88 # TODO: Handle errors.
89 preserve_fn(self.backend.store_file)(path, file_info)
90 return defer.succeed(None)
91
92 def fetch(self, path, file_info):
93 return self.backend.fetch(path, file_info)
94
95
96 class FileStorageProviderBackend(StorageProvider):
97 """A storage provider that stores files in a directory on a filesystem.
98
99 Args:
100 hs (HomeServer)
101 config: The config returned by `parse_config`.
102 """
103
104 def __init__(self, hs, config):
105 self.cache_directory = hs.config.media_store_path
106 self.base_directory = config
107
108 def store_file(self, path, file_info):
109 """See StorageProvider.store_file"""
110
111 primary_fname = os.path.join(self.cache_directory, path)
112 backup_fname = os.path.join(self.base_directory, path)
113
114 dirname = os.path.dirname(backup_fname)
115 if not os.path.exists(dirname):
116 os.makedirs(dirname)
117
118 return threads.deferToThread(
119 shutil.copyfile, primary_fname, backup_fname,
120 )
121
122 def fetch(self, path, file_info):
123 """See StorageProvider.fetch"""
124
125 backup_fname = os.path.join(self.base_directory, path)
126 if os.path.isfile(backup_fname):
127 return FileResponder(open(backup_fname, "rb"))
128
129 @staticmethod
130 def parse_config(config):
131 """Called on startup to parse config supplied. This should parse
132 the config and raise if there is a problem.
133
134 The returned value is passed into the constructor.
135
136 In this case we only care about a single param, the directory, so let's
137 just pull that out.
138 """
139 return Config.ensure_directory(config["directory"])
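
As a rough illustration of the StorageProvider interface above, the following sketch implements ``store_file`` and ``fetch`` the same way FileStorageProviderBackend does, but keeps file contents in memory. It is not part of the diff; the class name and its in-memory dict are purely hypothetical, while ``StorageProvider`` and ``FileResponder`` are the classes defined and imported in this module.

# Illustrative only; a toy provider following the interface above.
import os
from io import BytesIO

from twisted.internet import defer


class InMemoryStorageProviderBackend(StorageProvider):
    """Keeps file contents in a dict keyed by the relative media path."""

    def __init__(self, hs, config):
        # like FileStorageProviderBackend, we read uploads back out of the
        # primary local media cache when asked to store them
        self.cache_directory = hs.config.media_store_path
        self._files = {}

    def store_file(self, path, file_info):
        """See StorageProvider.store_file"""
        local_fname = os.path.join(self.cache_directory, path)
        with open(local_fname, "rb") as f:
            self._files[path] = f.read()
        return defer.succeed(None)

    def fetch(self, path, file_info):
        """See StorageProvider.fetch"""
        if path in self._files:
            return FileResponder(BytesIO(self._files[path]))
        # returning None tells MediaStorage to try the next provider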
1313 # limitations under the License.
1414
1515
16 from ._base import parse_media_id, respond_404, respond_with_file
16 from ._base import (
17 parse_media_id, respond_404, respond_with_file, FileInfo,
18 respond_with_responder,
19 )
1720 from twisted.web.resource import Resource
1821 from synapse.http.servlet import parse_string, parse_integer
1922 from synapse.http.server import request_handler, set_cors_headers
2932 class ThumbnailResource(Resource):
3033 isLeaf = True
3134
32 def __init__(self, hs, media_repo):
35 def __init__(self, hs, media_repo, media_storage):
3336 Resource.__init__(self)
3437
3538 self.store = hs.get_datastore()
36 self.filepaths = media_repo.filepaths
3739 self.media_repo = media_repo
40 self.media_storage = media_storage
3841 self.dynamic_thumbnails = hs.config.dynamic_thumbnails
3942 self.server_name = hs.hostname
4043 self.version_string = hs.version_string
6366 yield self._respond_local_thumbnail(
6467 request, media_id, width, height, method, m_type
6568 )
69 self.media_repo.mark_recently_accessed(None, media_id)
6670 else:
6771 if self.dynamic_thumbnails:
6872 yield self._select_or_generate_remote_thumbnail(
7478 request, server_name, media_id,
7579 width, height, method, m_type
7680 )
81 self.media_repo.mark_recently_accessed(server_name, media_id)
7782
7883 @defer.inlineCallbacks
7984 def _respond_local_thumbnail(self, request, media_id, width, height,
8085 method, m_type):
8186 media_info = yield self.store.get_local_media(media_id)
8287
83 if not media_info or media_info["quarantined_by"]:
88 if not media_info:
8489 respond_404(request)
8590 return
86
87 # if media_info["media_type"] == "image/svg+xml":
88 # file_path = self.filepaths.local_media_filepath(media_id)
89 # yield respond_with_file(request, media_info["media_type"], file_path)
90 # return
91 if media_info["quarantined_by"]:
92 logger.info("Media is quarantined")
93 respond_404(request)
94 return
9195
9296 thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
9397
9599 thumbnail_info = self._select_thumbnail(
96100 width, height, method, m_type, thumbnail_infos
97101 )
98 t_width = thumbnail_info["thumbnail_width"]
99 t_height = thumbnail_info["thumbnail_height"]
100 t_type = thumbnail_info["thumbnail_type"]
101 t_method = thumbnail_info["thumbnail_method"]
102
103 if media_info["url_cache"]:
104 # TODO: Check the file still exists, if it doesn't we can redownload
105 # it from the url `media_info["url_cache"]`
106 file_path = self.filepaths.url_cache_thumbnail(
107 media_id, t_width, t_height, t_type, t_method,
108 )
109 else:
110 file_path = self.filepaths.local_media_thumbnail(
111 media_id, t_width, t_height, t_type, t_method,
112 )
113 yield respond_with_file(request, t_type, file_path)
114
115 else:
116 yield self._respond_default_thumbnail(
117 request, media_info, width, height, method, m_type,
102
103 file_info = FileInfo(
104 server_name=None, file_id=media_id,
105 url_cache=media_info["url_cache"],
106 thumbnail=True,
107 thumbnail_width=thumbnail_info["thumbnail_width"],
108 thumbnail_height=thumbnail_info["thumbnail_height"],
109 thumbnail_type=thumbnail_info["thumbnail_type"],
110 thumbnail_method=thumbnail_info["thumbnail_method"],
118111 )
112
113 t_type = file_info.thumbnail_type
114 t_length = thumbnail_info["thumbnail_length"]
115
116 responder = yield self.media_storage.fetch_media(file_info)
117 yield respond_with_responder(request, responder, t_type, t_length)
118 else:
119 logger.info("Couldn't find any generated thumbnails")
120 respond_404(request)
119121
120122 @defer.inlineCallbacks
121123 def _select_or_generate_local_thumbnail(self, request, media_id, desired_width,
123125 desired_type):
124126 media_info = yield self.store.get_local_media(media_id)
125127
126 if not media_info or media_info["quarantined_by"]:
128 if not media_info:
127129 respond_404(request)
128130 return
129
130 # if media_info["media_type"] == "image/svg+xml":
131 # file_path = self.filepaths.local_media_filepath(media_id)
132 # yield respond_with_file(request, media_info["media_type"], file_path)
133 # return
131 if media_info["quarantined_by"]:
132 logger.info("Media is quarantined")
133 respond_404(request)
134 return
134135
135136 thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
136137 for info in thumbnail_infos:
140141 t_type = info["thumbnail_type"] == desired_type
141142
142143 if t_w and t_h and t_method and t_type:
143 if media_info["url_cache"]:
144 # TODO: Check the file still exists, if it doesn't we can redownload
145 # it from the url `media_info["url_cache"]`
146 file_path = self.filepaths.url_cache_thumbnail(
147 media_id, desired_width, desired_height, desired_type,
148 desired_method,
149 )
150 else:
151 file_path = self.filepaths.local_media_thumbnail(
152 media_id, desired_width, desired_height, desired_type,
153 desired_method,
154 )
155 yield respond_with_file(request, desired_type, file_path)
156 return
157
158 logger.debug("We don't have a local thumbnail of that size. Generating")
144 file_info = FileInfo(
145 server_name=None, file_id=media_id,
146 url_cache=media_info["url_cache"],
147 thumbnail=True,
148 thumbnail_width=info["thumbnail_width"],
149 thumbnail_height=info["thumbnail_height"],
150 thumbnail_type=info["thumbnail_type"],
151 thumbnail_method=info["thumbnail_method"],
152 )
153
154 t_type = file_info.thumbnail_type
155 t_length = info["thumbnail_length"]
156
157 responder = yield self.media_storage.fetch_media(file_info)
158 if responder:
159 yield respond_with_responder(request, responder, t_type, t_length)
160 return
161
162 logger.debug("We don't have a thumbnail of that size. Generating")
159163
160164 # Okay, so we generate one.
161165 file_path = yield self.media_repo.generate_local_exact_thumbnail(
162 media_id, desired_width, desired_height, desired_method, desired_type
166 media_id, desired_width, desired_height, desired_method, desired_type,
167 url_cache=media_info["url_cache"],
163168 )
164169
165170 if file_path:
166171 yield respond_with_file(request, desired_type, file_path)
167172 else:
168 yield self._respond_default_thumbnail(
169 request, media_info, desired_width, desired_height,
170 desired_method, desired_type,
171 )
173 logger.warn("Failed to generate thumbnail")
174 respond_404(request)
172175
173176 @defer.inlineCallbacks
174177 def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
175178 desired_width, desired_height,
176179 desired_method, desired_type):
177 media_info = yield self.media_repo.get_remote_media(server_name, media_id)
178
179 # if media_info["media_type"] == "image/svg+xml":
180 # file_path = self.filepaths.remote_media_filepath(server_name, media_id)
181 # yield respond_with_file(request, media_info["media_type"], file_path)
182 # return
180 media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
183181
184182 thumbnail_infos = yield self.store.get_remote_media_thumbnails(
185183 server_name, media_id,
194192 t_type = info["thumbnail_type"] == desired_type
195193
196194 if t_w and t_h and t_method and t_type:
197 file_path = self.filepaths.remote_media_thumbnail(
198 server_name, file_id, desired_width, desired_height,
199 desired_type, desired_method,
200 )
201 yield respond_with_file(request, desired_type, file_path)
202 return
203
204 logger.debug("We don't have a local thumbnail of that size. Generating")
195 file_info = FileInfo(
196 server_name=server_name, file_id=media_info["filesystem_id"],
197 thumbnail=True,
198 thumbnail_width=info["thumbnail_width"],
199 thumbnail_height=info["thumbnail_height"],
200 thumbnail_type=info["thumbnail_type"],
201 thumbnail_method=info["thumbnail_method"],
202 )
203
204 t_type = file_info.thumbnail_type
205 t_length = info["thumbnail_length"]
206
207 responder = yield self.media_storage.fetch_media(file_info)
208 if responder:
209 yield respond_with_responder(request, responder, t_type, t_length)
210 return
211
212 logger.debug("We don't have a thumbnail of that size. Generating")
205213
206214 # Okay, so we generate one.
207215 file_path = yield self.media_repo.generate_remote_exact_thumbnail(
212220 if file_path:
213221 yield respond_with_file(request, desired_type, file_path)
214222 else:
215 yield self._respond_default_thumbnail(
216 request, media_info, desired_width, desired_height,
217 desired_method, desired_type,
218 )
223 logger.warn("Failed to generate thumbnail")
224 respond_404(request)
219225
220226 @defer.inlineCallbacks
221227 def _respond_remote_thumbnail(self, request, server_name, media_id, width,
222228 height, method, m_type):
223229 # TODO: Don't download the whole remote file
224 # We should proxy the thumbnail from the remote server instead.
225 media_info = yield self.media_repo.get_remote_media(server_name, media_id)
226
227 # if media_info["media_type"] == "image/svg+xml":
228 # file_path = self.filepaths.remote_media_filepath(server_name, media_id)
229 # yield respond_with_file(request, media_info["media_type"], file_path)
230 # return
230 # We should proxy the thumbnail from the remote server instead of
231 # downloading the remote file and generating our own thumbnails.
232 media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
231233
232234 thumbnail_infos = yield self.store.get_remote_media_thumbnails(
233235 server_name, media_id,
237239 thumbnail_info = self._select_thumbnail(
238240 width, height, method, m_type, thumbnail_infos
239241 )
240 t_width = thumbnail_info["thumbnail_width"]
241 t_height = thumbnail_info["thumbnail_height"]
242 t_type = thumbnail_info["thumbnail_type"]
243 t_method = thumbnail_info["thumbnail_method"]
244 file_id = thumbnail_info["filesystem_id"]
242 file_info = FileInfo(
243 server_name=server_name, file_id=media_info["filesystem_id"],
244 thumbnail=True,
245 thumbnail_width=thumbnail_info["thumbnail_width"],
246 thumbnail_height=thumbnail_info["thumbnail_height"],
247 thumbnail_type=thumbnail_info["thumbnail_type"],
248 thumbnail_method=thumbnail_info["thumbnail_method"],
249 )
250
251 t_type = file_info.thumbnail_type
245252 t_length = thumbnail_info["thumbnail_length"]
246253
247 file_path = self.filepaths.remote_media_thumbnail(
248 server_name, file_id, t_width, t_height, t_type, t_method,
249 )
250 yield respond_with_file(request, t_type, file_path, t_length)
251 else:
252 yield self._respond_default_thumbnail(
253 request, media_info, width, height, method, m_type,
254 )
255
256 @defer.inlineCallbacks
257 def _respond_default_thumbnail(self, request, media_info, width, height,
258 method, m_type):
259 # XXX: how is this meant to work? store.get_default_thumbnails
260 # appears to always return [] so won't this always 404?
261 media_type = media_info["media_type"]
262 top_level_type = media_type.split("/")[0]
263 sub_type = media_type.split("/")[-1].split(";")[0]
264 thumbnail_infos = yield self.store.get_default_thumbnails(
265 top_level_type, sub_type,
266 )
267 if not thumbnail_infos:
268 thumbnail_infos = yield self.store.get_default_thumbnails(
269 top_level_type, "_default",
270 )
271 if not thumbnail_infos:
272 thumbnail_infos = yield self.store.get_default_thumbnails(
273 "_default", "_default",
274 )
275 if not thumbnail_infos:
276 respond_404(request)
277 return
278
279 thumbnail_info = self._select_thumbnail(
280 width, height, "crop", m_type, thumbnail_infos
281 )
282
283 t_width = thumbnail_info["thumbnail_width"]
284 t_height = thumbnail_info["thumbnail_height"]
285 t_type = thumbnail_info["thumbnail_type"]
286 t_method = thumbnail_info["thumbnail_method"]
287 t_length = thumbnail_info["thumbnail_length"]
288
289 file_path = self.filepaths.default_thumbnail(
290 top_level_type, sub_type, t_width, t_height, t_type, t_method,
291 )
292 yield respond_with_file(request, t_type, file_path, t_length)
254 responder = yield self.media_storage.fetch_media(file_info)
255 yield respond_with_responder(request, responder, t_type, t_length)
256 else:
257 logger.info("Failed to find any generated thumbnails")
258 respond_404(request)
293259
294260 def _select_thumbnail(self, desired_width, desired_height, desired_method,
295261 desired_type, thumbnail_infos):
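
The thumbnail changes above repeatedly follow the same pattern: describe the thumbnail with a FileInfo, ask MediaStorage for a Responder, and hand it to ``respond_with_responder``. Condensed into one hypothetical helper (illustrative only; ``FileInfo`` and ``respond_with_responder`` are the ``._base`` imports shown at the top of the file):

# Condensed, hypothetical helper showing the FileInfo -> fetch_media ->
# respond_with_responder flow used throughout the thumbnail handlers above.
from twisted.internet import defer

from synapse.rest.media.v1._base import FileInfo, respond_with_responder


@defer.inlineCallbacks
def serve_local_thumbnail(request, media_storage, media_id, thumbnail_info,
                          url_cache=False):
    file_info = FileInfo(
        server_name=None, file_id=media_id,
        url_cache=url_cache,
        thumbnail=True,
        thumbnail_width=thumbnail_info["thumbnail_width"],
        thumbnail_height=thumbnail_info["thumbnail_height"],
        thumbnail_type=thumbnail_info["thumbnail_type"],
        thumbnail_method=thumbnail_info["thumbnail_method"],
    )

    # checks the local media cache first, then any configured storage providers
    responder = yield media_storage.fetch_media(file_info)

    yield respond_with_responder(
        request, responder,
        file_info.thumbnail_type, thumbnail_info["thumbnail_length"],
    )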
3131 from synapse.crypto.keyring import Keyring
3232 from synapse.events.builder import EventBuilderFactory
3333 from synapse.events.spamcheck import SpamChecker
34 from synapse.federation import initialize_http_replication
34 from synapse.federation.federation_client import FederationClient
35 from synapse.federation.federation_server import FederationServer
3536 from synapse.federation.send_queue import FederationRemoteSendQueue
37 from synapse.federation.federation_server import FederationHandlerRegistry
3638 from synapse.federation.transport.client import TransportLayerClient
3739 from synapse.federation.transaction_queue import TransactionQueue
3840 from synapse.handlers import Handlers
3941 from synapse.handlers.appservice import ApplicationServicesHandler
4042 from synapse.handlers.auth import AuthHandler, MacaroonGeneartor
43 from synapse.handlers.deactivate_account import DeactivateAccountHandler
4144 from synapse.handlers.devicemessage import DeviceMessageHandler
4245 from synapse.handlers.device import DeviceHandler
4346 from synapse.handlers.e2e_keys import E2eKeysHandler
4447 from synapse.handlers.presence import PresenceHandler
4548 from synapse.handlers.room_list import RoomListHandler
49 from synapse.handlers.room_member import RoomMemberMasterHandler
50 from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
51 from synapse.handlers.set_password import SetPasswordHandler
4652 from synapse.handlers.sync import SyncHandler
4753 from synapse.handlers.typing import TypingHandler
4854 from synapse.handlers.events import EventHandler, EventStreamHandler
4955 from synapse.handlers.initial_sync import InitialSyncHandler
5056 from synapse.handlers.receipts import ReceiptsHandler
5157 from synapse.handlers.read_marker import ReadMarkerHandler
52 from synapse.handlers.user_directory import UserDirectoyHandler
58 from synapse.handlers.user_directory import UserDirectoryHandler
5359 from synapse.handlers.groups_local import GroupsLocalHandler
5460 from synapse.handlers.profile import ProfileHandler
61 from synapse.handlers.message import EventCreationHandler
5562 from synapse.groups.groups_server import GroupsServerHandler
5663 from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning
5764 from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
5966 from synapse.notifier import Notifier
6067 from synapse.push.action_generator import ActionGenerator
6168 from synapse.push.pusherpool import PusherPool
62 from synapse.rest.media.v1.media_repository import MediaRepository
63 from synapse.state import StateHandler
69 from synapse.rest.media.v1.media_repository import (
70 MediaRepository,
71 MediaRepositoryResource,
72 )
73 from synapse.state import StateHandler, StateResolutionHandler
6474 from synapse.storage import DataStore
6575 from synapse.streams.events import EventSources
6676 from synapse.util import Clock
8999 """
90100
91101 DEPENDENCIES = [
92 'config',
93 'clock',
94102 'http_client',
95103 'db_pool',
96 'persistence_service',
97 'replication_layer',
98 'datastore',
104 'federation_client',
105 'federation_server',
99106 'handlers',
100107 'v1auth',
101108 'auth',
102 'rest_servlet_factory',
103109 'state_handler',
110 'state_resolution_handler',
104111 'presence_handler',
105112 'sync_handler',
106113 'typing_handler',
116123 'application_service_handler',
117124 'device_message_handler',
118125 'profile_handler',
126 'event_creation_handler',
127 'deactivate_account_handler',
128 'set_password_handler',
119129 'notifier',
120 'distributor',
121 'client_resource',
122 'resource_for_federation',
123 'resource_for_static_content',
124 'resource_for_web_client',
125 'resource_for_content_repo',
126 'resource_for_server_key',
127 'resource_for_server_key_v2',
128 'resource_for_media_repository',
129 'resource_for_metrics',
130130 'event_sources',
131 'ratelimiter',
132131 'keyring',
133132 'pusherpool',
134133 'event_builder_factory',
136135 'http_client_context_factory',
137136 'simple_http_client',
138137 'media_repository',
138 'media_repository_resource',
139139 'federation_transport_client',
140140 'federation_sender',
141141 'receipts_handler',
149149 'groups_attestation_signing',
150150 'groups_attestation_renewer',
151151 'spam_checker',
152 'room_member_handler',
153 'federation_registry',
152154 ]
153155
154156 def __init__(self, hostname, **kwargs):
182184 def is_mine_id(self, string):
183185 return string.split(":", 1)[1] == self.hostname
184186
185 def build_replication_layer(self):
186 return initialize_http_replication(self)
187 def get_clock(self):
188 return self.clock
189
190 def get_datastore(self):
191 return self.datastore
192
193 def get_config(self):
194 return self.config
195
196 def get_distributor(self):
197 return self.distributor
198
199 def get_ratelimiter(self):
200 return self.ratelimiter
201
202 def build_federation_client(self):
203 return FederationClient(self)
204
205 def build_federation_server(self):
206 return FederationServer(self)
187207
188208 def build_handlers(self):
189209 return Handlers(self)
216236 def build_state_handler(self):
217237 return StateHandler(self)
218238
239 def build_state_resolution_handler(self):
240 return StateResolutionHandler(self)
241
219242 def build_presence_handler(self):
220243 return PresenceHandler(self)
221244
263286
264287 def build_profile_handler(self):
265288 return ProfileHandler(self)
289
290 def build_event_creation_handler(self):
291 return EventCreationHandler(self)
292
293 def build_deactivate_account_handler(self):
294 return DeactivateAccountHandler(self)
295
296 def build_set_password_handler(self):
297 return SetPasswordHandler(self)
266298
267299 def build_event_sources(self):
268300 return EventSources(self)
292324 name,
293325 **self.db_config.get("args", {})
294326 )
327
328 def get_db_conn(self, run_new_connection=True):
329 """Makes a new connection to the database, skipping the db pool
330
331 Returns:
332 Connection: a connection object implementing the PEP-249 spec
333 """
334 # Any param beginning with cp_ is a parameter for adbapi, and should
335 # not be passed to the database engine.
336 db_params = {
337 k: v for k, v in self.db_config.get("args", {}).items()
338 if not k.startswith("cp_")
339 }
340 db_conn = self.database_engine.module.connect(**db_params)
341 if run_new_connection:
342 self.database_engine.on_new_connection(db_conn)
343 return db_conn
344
345 def build_media_repository_resource(self):
346 # build the media repo resource. This indirects through the HomeServer
347 # to ensure that we only have a single instance of the resource
348 return MediaRepositoryResource(self)
295349
296350 def build_media_repository(self):
297351 return MediaRepository(self)
320374 return ActionGenerator(self)
321375
322376 def build_user_directory_handler(self):
323 return UserDirectoyHandler(self)
377 return UserDirectoryHandler(self)
324378
325379 def build_groups_local_handler(self):
326380 return GroupsLocalHandler(self)
336390
337391 def build_spam_checker(self):
338392 return SpamChecker(self)
393
394 def build_room_member_handler(self):
395 if self.config.worker_app:
396 return RoomMemberWorkerHandler(self)
397 return RoomMemberMasterHandler(self)
398
399 def build_federation_registry(self):
400 return FederationHandlerRegistry()
339401
340402 def remove_pusher(self, app_id, push_key, user_id):
341403 return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
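
The server.py hunk above mostly extends the HomeServer's DEPENDENCIES list and its ``build_*`` methods. The sketch below is a simplified, hypothetical illustration of that convention (the real accessors are generated from DEPENDENCIES inside ``synapse.server``): each ``get_<name>()`` lazily calls ``build_<name>()`` once and caches the result, so for example ``get_media_repository()`` always returns the same MediaRepository instance.

# Toy illustration of the DEPENDENCIES/build_*/get_* convention; not the real
# HomeServer, which generates the get_* accessors from DEPENDENCIES.
class MiniHomeServer(object):
    DEPENDENCIES = ["media_repository", "state_resolution_handler"]

    def __init__(self):
        self._built = {}

    def _get(self, name):
        if name not in self._built:
            # construct lazily and cache, so there is only ever one instance
            self._built[name] = getattr(self, "build_%s" % name)()
        return self._built[name]

    def build_media_repository(self):
        return object()  # stand-in for MediaRepository(self)

    def get_media_repository(self):
        return self._get("media_repository")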
22 import synapse.federation.transport.client
33 import synapse.handlers
44 import synapse.handlers.auth
5 import synapse.handlers.deactivate_account
56 import synapse.handlers.device
67 import synapse.handlers.e2e_keys
8 import synapse.handlers.set_password
9 import synapse.rest.media.v1.media_repository
10 import synapse.state
711 import synapse.storage
8 import synapse.state
12
913
1014 class HomeServer(object):
1115 def get_auth(self) -> synapse.api.auth.Auth:
2933 def get_state_handler(self) -> synapse.state.StateHandler:
3034 pass
3135
36 def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
37 pass
38
39 def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
40 pass
41
42 def get_set_password_handler(self) -> synapse.handlers.set_password.SetPasswordHandler:
43 pass
44
3245 def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue:
3346 pass
3447
3548 def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient:
3649 pass
50
51 def get_media_repository_resource(self) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource:
52 pass
53
54 def get_media_repository(self) -> synapse.rest.media.v1.media_repository.MediaRepository:
55 pass
5757 __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
5858
5959 def __init__(self, state, state_group, prev_group=None, delta_ids=None):
60 # dict[(str, str), str] map from (type, state_key) to event_id
6061 self.state = frozendict(state)
62
63 # the ID of a state group if one and only one is involved.
64 # otherwise, None.
6165 self.state_group = state_group
6266
6367 self.prev_group = prev_group
8084
8185
8286 class StateHandler(object):
83 """ Responsible for doing state conflict resolution.
87 """Fetches bits of state from the stores, and does state resolution
88 where necessary
8489 """
8590
8691 def __init__(self, hs):
8792 self.clock = hs.get_clock()
8893 self.store = hs.get_datastore()
8994 self.hs = hs
90
91 # dict of set of event_ids -> _StateCacheEntry.
92 self._state_cache = None
93 self.resolve_linearizer = Linearizer(name="state_resolve_lock")
95 self._state_resolution_handler = hs.get_state_resolution_handler()
9496
9597 def start_caching(self):
96 logger.debug("start_caching")
97
98 self._state_cache = ExpiringCache(
99 cache_name="state_cache",
100 clock=self.clock,
101 max_len=SIZE_OF_CACHE,
102 expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
103 iterable=True,
104 reset_expiry_on_get=True,
105 )
106
107 self._state_cache.start()
98 # TODO: remove this shim
99 self._state_resolution_handler.start_caching()
108100
109101 @defer.inlineCallbacks
110102 def get_current_state(self, room_id, event_type=None, state_key="",
126118 latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
127119
128120 logger.debug("calling resolve_state_groups from get_current_state")
129 ret = yield self.resolve_state_groups(room_id, latest_event_ids)
121 ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
130122 state = ret.state
131123
132124 if event_type:
145137 defer.returnValue(state)
146138
147139 @defer.inlineCallbacks
148 def get_current_state_ids(self, room_id, event_type=None, state_key="",
149 latest_event_ids=None):
140 def get_current_state_ids(self, room_id, latest_event_ids=None):
141 """Get the current state, or the state at a set of events, for a room
142
143 Args:
144 room_id (str):
145
146 latest_event_ids (iterable[str]|None): if given, the forward
147 extremities to resolve. If None, we look them up from the
148 database (via a cache)
149
150 Returns:
151 Deferred[dict[(str, str), str)]]: the state dict, mapping from
152 (event_type, state_key) -> event_id
153 """
150154 if not latest_event_ids:
151155 latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
152156
153157 logger.debug("calling resolve_state_groups from get_current_state_ids")
154 ret = yield self.resolve_state_groups(room_id, latest_event_ids)
158 ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
155159 state = ret.state
156
157 if event_type:
158 defer.returnValue(state.get((event_type, state_key)))
159 return
160160
161161 defer.returnValue(state)
162162
165165 if not latest_event_ids:
166166 latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
167167 logger.debug("calling resolve_state_groups from get_current_user_in_room")
168 entry = yield self.resolve_state_groups(room_id, latest_event_ids)
168 entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
169169 joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
170170 defer.returnValue(joined_users)
171171
174174 if not latest_event_ids:
175175 latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
176176 logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
177 entry = yield self.resolve_state_groups(room_id, latest_event_ids)
177 entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
178178 joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
179179 defer.returnValue(joined_hosts)
180180
182182 def compute_event_context(self, event, old_state=None):
183183 """Build an EventContext structure for the event.
184184
185 This works out what the current state should be for the event, and
186 generates a new state group if necessary.
187
185188 Args:
186189 event (synapse.events.EventBase):
190 old_state (dict|None): The state at the event if it can't be
191 calculated from existing events. This is normally only specified
192 when receiving an event from federation where we don't have the
193 prev events for, e.g. when backfilling.
187194 Returns:
188195 synapse.events.snapshot.EventContext:
189196 """
207214 context.current_state_ids = {}
208215 context.prev_state_ids = {}
209216 context.prev_state_events = []
210 context.state_group = self.store.get_next_state_group()
217
218 # We don't store state for outliers, so we don't generate a state
219 # group for it.
220 context.state_group = None
221
211222 defer.returnValue(context)
212223
213224 if old_state:
225 # We already have the state, so we don't need to calculate it.
226 # Let's just correctly fill out the context and create a
227 # new state group for it.
228
214229 context = EventContext()
215230 context.prev_state_ids = {
216231 (s.type, s.state_key): s.event_id for s in old_state
217232 }
218 context.state_group = self.store.get_next_state_group()
219233
220234 if event.is_state():
221235 key = (event.type, event.state_key)
228242 else:
229243 context.current_state_ids = context.prev_state_ids
230244
245 context.state_group = yield self.store.store_state_group(
246 event.event_id,
247 event.room_id,
248 prev_group=None,
249 delta_ids=None,
250 current_state_ids=context.current_state_ids,
251 )
252
231253 context.prev_state_events = []
232254 defer.returnValue(context)
233255
234256 logger.debug("calling resolve_state_groups from compute_event_context")
235 entry = yield self.resolve_state_groups(
257 entry = yield self.resolve_state_groups_for_events(
236258 event.room_id, [e for e, _ in event.prev_events],
237259 )
238260
241263 context = EventContext()
242264 context.prev_state_ids = curr_state
243265 if event.is_state():
244 context.state_group = self.store.get_next_state_group()
266 # If this is a state event then we need to create a new state
267 # group for the state after this event.
245268
246269 key = (event.type, event.state_key)
247270 if key in context.prev_state_ids:
252275 context.current_state_ids[key] = event.event_id
253276
254277 if entry.state_group:
278 # If the state at the event has a state group assigned then
279 # we can use that as the prev group
255280 context.prev_group = entry.state_group
256281 context.delta_ids = {
257282 key: event.event_id
258283 }
259284 elif entry.prev_group:
285 # If the state at the event only has a prev group, then we can
286 # use that as a prev group too.
260287 context.prev_group = entry.prev_group
261288 context.delta_ids = dict(entry.delta_ids)
262289 context.delta_ids[key] = event.event_id
290
291 context.state_group = yield self.store.store_state_group(
292 event.event_id,
293 event.room_id,
294 prev_group=context.prev_group,
295 delta_ids=context.delta_ids,
296 current_state_ids=context.current_state_ids,
297 )
263298 else:
264 if entry.state_group is None:
265 entry.state_group = self.store.get_next_state_group()
266 entry.state_id = entry.state_group
267
268 context.state_group = entry.state_group
269299 context.current_state_ids = context.prev_state_ids
270300 context.prev_group = entry.prev_group
271301 context.delta_ids = entry.delta_ids
272302
303 if entry.state_group is None:
304 entry.state_group = yield self.store.store_state_group(
305 event.event_id,
306 event.room_id,
307 prev_group=entry.prev_group,
308 delta_ids=entry.delta_ids,
309 current_state_ids=context.current_state_ids,
310 )
311 entry.state_id = entry.state_group
312
313 context.state_group = entry.state_group
314
273315 context.prev_state_events = []
274316 defer.returnValue(context)
275317
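
compute_event_context above now describes each new state group as a delta against the previous group where possible. A small worked example of the prev_group/delta_ids idea follows; all event IDs and keys are made up for illustration.

# Made-up example of the prev_group/delta_ids relationship: a state event
# replaces exactly one (type, state_key) entry, so the new group can be
# stored as a delta against the previous one rather than a full copy.
prev_state_ids = {
    ("m.room.member", "@alice:example.com"): "$join_event",
    ("m.room.name", ""): "$old_name_event",
}

new_key = ("m.room.name", "")
current_state_ids = dict(prev_state_ids)
current_state_ids[new_key] = "$new_name_event"

delta_ids = {new_key: "$new_name_event"}
# store_state_group(event_id, room_id, prev_group=<previous group id>,
#                   delta_ids=delta_ids, current_state_ids=current_state_ids)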
276318 @defer.inlineCallbacks
277 @log_function
278 def resolve_state_groups(self, room_id, event_ids):
319 def resolve_state_groups_for_events(self, room_id, event_ids):
279320 """ Given a list of event_ids this method fetches the state at each
280321 event, resolves conflicts between them and returns them.
281322
323 Args:
324 room_id (str):
325 event_ids (list[str]):
326
282327 Returns:
283 a Deferred tuple of (`state_group`, `state`, `prev_state`).
284 `state_group` is the name of a state group if one and only one is
285 involved. `state` is a map from (type, state_key) to event, and
286 `prev_state` is a list of event ids.
328 Deferred[_StateCacheEntry]: resolved state
287329 """
288330 logger.debug("resolve_state_groups event_ids %s", event_ids)
289331
294336 room_id, event_ids
295337 )
296338
297 logger.debug(
298 "resolve_state_groups state_groups %s",
299 state_groups_ids.keys()
300 )
301
302 group_names = frozenset(state_groups_ids.keys())
303 if len(group_names) == 1:
339 if len(state_groups_ids) == 1:
304340 name, state_list = state_groups_ids.items().pop()
305341
306342 prev_group, delta_ids = yield self.store.get_state_group_delta(name)
311347 prev_group=prev_group,
312348 delta_ids=delta_ids,
313349 ))
350
351 result = yield self._state_resolution_handler.resolve_state_groups(
352 room_id, state_groups_ids, None, self._state_map_factory,
353 )
354 defer.returnValue(result)
355
356 def _state_map_factory(self, ev_ids):
357 return self.store.get_events(
358 ev_ids, get_prev_content=False, check_redacted=False,
359 )
360
361 def resolve_events(self, state_sets, event):
362 logger.info(
363 "Resolving state for %s with %d groups", event.room_id, len(state_sets)
364 )
365 state_set_ids = [{
366 (ev.type, ev.state_key): ev.event_id
367 for ev in st
368 } for st in state_sets]
369
370 state_map = {
371 ev.event_id: ev
372 for st in state_sets
373 for ev in st
374 }
375
376 with Measure(self.clock, "state._resolve_events"):
377 new_state = resolve_events_with_state_map(state_set_ids, state_map)
378
379 new_state = {
380 key: state_map[ev_id] for key, ev_id in new_state.items()
381 }
382
383 return new_state
384
385
386 class StateResolutionHandler(object):
387 """Responsible for doing state conflict resolution.
388
389 Note that the storage layer depends on this handler, so all functions must
390 be storage-independent.
391 """
392 def __init__(self, hs):
393 self.clock = hs.get_clock()
394
395 # dict of set of event_ids -> _StateCacheEntry.
396 self._state_cache = None
397 self.resolve_linearizer = Linearizer(name="state_resolve_lock")
398
399 def start_caching(self):
400 logger.debug("start_caching")
401
402 self._state_cache = ExpiringCache(
403 cache_name="state_cache",
404 clock=self.clock,
405 max_len=SIZE_OF_CACHE,
406 expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
407 iterable=True,
408 reset_expiry_on_get=True,
409 )
410
411 self._state_cache.start()
412
413 @defer.inlineCallbacks
414 @log_function
415 def resolve_state_groups(
416 self, room_id, state_groups_ids, event_map, state_map_factory,
417 ):
418 """Resolves conflicts between a set of state groups
419
420 Always generates a new state group (unless we hit the cache), so should
421 not be called for a single state group
422
423 Args:
424 room_id (str): room we are resolving for (used for logging)
425 state_groups_ids (dict[int, dict[(str, str), str]]):
426 map from state group id to the state in that state group
427 (where 'state' is a map from state key to event id)
428
429 event_map(dict[str,FrozenEvent]|None):
430 a dict from event_id to event, for any events that we happen to
431 have in flight (eg, those currently being persisted). This will be
432 used as a starting point for finding the state we need; any missing
433 events will be requested via state_map_factory.
434
435 If None, all events will be fetched via state_map_factory.
436
437 Returns:
438 Deferred[_StateCacheEntry]: resolved state
439 """
440 logger.debug(
441 "resolve_state_groups state_groups %s",
442 state_groups_ids.keys()
443 )
444
445 group_names = frozenset(state_groups_ids.keys())
314446
315447 with (yield self.resolve_linearizer.queue(group_names)):
316448 if self._state_cache is not None:
340472 if conflicted_state:
341473 logger.info("Resolving conflicted state for %r", room_id)
342474 with Measure(self.clock, "state._resolve_events"):
343 new_state = yield resolve_events(
475 new_state = yield resolve_events_with_factory(
344476 state_groups_ids.values(),
345 state_map_factory=lambda ev_ids: self.store.get_events(
346 ev_ids, get_prev_content=False, check_redacted=False,
347 ),
477 event_map=event_map,
478 state_map_factory=state_map_factory,
348479 )
349480 else:
350481 new_state = {
351482 key: e_ids.pop() for key, e_ids in state.items()
352483 }
353484
485 # if the new state matches any of the input state groups, we can
486 # use that state group again. Otherwise we will generate a state_id
487 # which will be used as a cache key for future resolutions, but
488 # not get persisted.
354489 state_group = None
355490 new_state_event_ids = frozenset(new_state.values())
356491 for sg, events in state_groups_ids.items():
387522
388523 defer.returnValue(cache)
389524
390 def resolve_events(self, state_sets, event):
391 logger.info(
392 "Resolving state for %s with %d groups", event.room_id, len(state_sets)
393 )
394 state_set_ids = [{
395 (ev.type, ev.state_key): ev.event_id
396 for ev in st
397 } for st in state_sets]
398
399 state_map = {
400 ev.event_id: ev
401 for st in state_sets
402 for ev in st
403 }
404
405 with Measure(self.clock, "state._resolve_events"):
406 new_state = resolve_events(state_set_ids, state_map)
407
408 new_state = {
409 key: state_map[ev_id] for key, ev_id in new_state.items()
410 }
411
412 return new_state
413
414525
415526 def _ordered_events(events):
416527 def key_func(e):
419530 return sorted(events, key=key_func)
420531
421532
422 def resolve_events(state_sets, state_map_factory):
533 def resolve_events_with_state_map(state_sets, state_map):
423534 """
424535 Args:
425536 state_sets(list): List of dicts of (type, state_key) -> event_id,
426537 which are the different state groups to resolve.
427 state_map_factory(dict|callable): If callable, then will be called
428 with a list of event_ids that are needed, and should return with
429 a Deferred of dict of event_id to event. Otherwise, should be
430 a dict from event_id to event of all events in state_sets.
538 state_map(dict): a dict from event_id to event, for all events in
539 state_sets.
431540
432541 Returns
433 dict[(str, str), synapse.events.FrozenEvent] is a map from
434 (type, state_key) to event.
542 dict[(str, str), str]:
543 a map from (type, state_key) to event_id.
435544 """
436545 if len(state_sets) == 1:
437546 return state_sets[0]
439548 unconflicted_state, conflicted_state = _seperate(
440549 state_sets,
441550 )
442
443 if callable(state_map_factory):
444 return _resolve_with_state_fac(
445 unconflicted_state, conflicted_state, state_map_factory
446 )
447
448 state_map = state_map_factory
449551
450552 auth_events = _create_auth_events_from_maps(
451553 unconflicted_state, conflicted_state, state_map
460562 """Takes the state_sets and figures out which keys are conflicted and
461563 which aren't. i.e., which have multiple different event_ids associated
462564 with them in different state sets.
565
566 Args:
567 state_sets(list[dict[(str, str), str]]):
568 List of dicts of (type, state_key) -> event_id, which are the
569 different state groups to resolve.
570
571 Returns:
572 (dict[(str, str), str], dict[(str, str), set[str]]):
573 A tuple of (unconflicted_state, conflicted_state), where:
574
575 unconflicted_state is a dict mapping (type, state_key)->event_id
576 for unconflicted state keys.
577
578 conflicted_state is a dict mapping (type, state_key) to a set of
579 event ids for conflicted state keys.
463580 """
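    # As a purely hypothetical illustration of the split described above:
    #   state_sets = [{("m.room.name", ""): "$a", ("m.room.topic", ""): "$t"},
    #                 {("m.room.name", ""): "$b", ("m.room.topic", ""): "$t"}]
    # the topic key maps to the same event id in every set, so it lands in
    # unconflicted_state as "$t"; the name key maps to different ids, so it
    # lands in conflicted_state as the set {"$a", "$b"}.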
464581 unconflicted_state = dict(state_sets[0])
465582 conflicted_state = {}
490607
491608
492609 @defer.inlineCallbacks
493 def _resolve_with_state_fac(unconflicted_state, conflicted_state,
494 state_map_factory):
610 def resolve_events_with_factory(state_sets, event_map, state_map_factory):
611 """
612 Args:
613 state_sets(list): List of dicts of (type, state_key) -> event_id,
614 which are the different state groups to resolve.
615
616 event_map(dict[str,FrozenEvent]|None):
617 a dict from event_id to event, for any events that we happen to
618 have in flight (eg, those currently being persisted). This will be
619 used as a starting point for finding the state we need; any missing
620 events will be requested via state_map_factory.
621
622 If None, all events will be fetched via state_map_factory.
623
624 state_map_factory(func): will be called
625 with a list of event_ids that are needed, and should return with
626 a Deferred of dict of event_id to event.
627
628 Returns
629 Deferred[dict[(str, str), str]]:
630 a map from (type, state_key) to event_id.
631 """
632 if len(state_sets) == 1:
633 defer.returnValue(state_sets[0])
634
635 unconflicted_state, conflicted_state = _seperate(
636 state_sets,
637 )
638
495639 needed_events = set(
496640 event_id
497641 for event_ids in conflicted_state.itervalues()
498642 for event_id in event_ids
499643 )
644 if event_map is not None:
645 needed_events -= set(event_map.iterkeys())
500646
501647 logger.info("Asking for %d conflicted events", len(needed_events))
502648
503649 # dict[str, FrozenEvent]: a map from state event id to event. Only includes
504 # the state events which are in conflict.
650 # the state events which are in conflict (and those in event_map)
505651 state_map = yield state_map_factory(needed_events)
652 if event_map is not None:
653 state_map.update(event_map)
506654
507655 # get the ids of the auth events which allow us to authenticate the
508656 # conflicted state, picking only from the unconflicting state.
514662
515663 new_needed_events = set(auth_events.itervalues())
516664 new_needed_events -= needed_events
665 if event_map is not None:
666 new_needed_events -= set(event_map.iterkeys())
517667
518668 logger.info("Asking for %d auth events", len(new_needed_events))
519669
559709 resolved_state = _resolve_state_events(
560710 conflicted_state, auth_events
561711 )
562 except:
712 except Exception:
563713 logger.exception("Failed to resolve state")
564714 raise
565715
00 # -*- coding: utf-8 -*-
11 # Copyright 2014-2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1819 from .appservice import (
1920 ApplicationServiceStore, ApplicationServiceTransactionStore
2021 )
21 from ._base import LoggingTransaction
2222 from .directory import DirectoryStore
2323 from .events import EventsStore
2424 from .presence import PresenceStore, UserPresenceState
103103 db_conn, "events", "stream_ordering", step=-1,
104104 extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
105105 )
106 self._receipts_id_gen = StreamIdGenerator(
107 db_conn, "receipts_linearized", "stream_id"
108 )
109 self._account_data_id_gen = StreamIdGenerator(
110 db_conn, "account_data_max_stream_id", "stream_id"
111 )
112106 self._presence_id_gen = StreamIdGenerator(
113107 db_conn, "presence_stream", "stream_id"
114108 )
123117 )
124118
125119 self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
126 self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
127120 self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
128121 self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
129122 self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
146139 else:
147140 self._cache_id_gen = None
148141
149 events_max = self._stream_id_gen.get_current_token()
150 event_cache_prefill, min_event_val = self._get_cache_dict(
151 db_conn, "events",
152 entity_column="room_id",
153 stream_column="stream_ordering",
154 max_value=events_max,
155 )
156 self._events_stream_cache = StreamChangeCache(
157 "EventsRoomStreamChangeCache", min_event_val,
158 prefilled_cache=event_cache_prefill,
159 )
160
161 self._membership_stream_cache = StreamChangeCache(
162 "MembershipStreamChangeCache", events_max,
163 )
164
165 account_max = self._account_data_id_gen.get_current_token()
166 self._account_data_stream_cache = StreamChangeCache(
167 "AccountDataAndTagsChangeCache", account_max,
168 )
169
170142 self._presence_on_startup = self._get_active_presence(db_conn)
171143
172144 presence_cache_prefill, min_presence_val = self._get_cache_dict(
178150 self.presence_stream_cache = StreamChangeCache(
179151 "PresenceStreamChangeCache", min_presence_val,
180152 prefilled_cache=presence_cache_prefill
181 )
182
183 push_rules_prefill, push_rules_id = self._get_cache_dict(
184 db_conn, "push_rules_stream",
185 entity_column="user_id",
186 stream_column="stream_id",
187 max_value=self._push_rules_stream_id_gen.get_current_token()[0],
188 )
189
190 self.push_rules_stream_cache = StreamChangeCache(
191 "PushRulesStreamChangeCache", push_rules_id,
192 prefilled_cache=push_rules_prefill,
193153 )
194154
195155 max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
226186 "DeviceListFederationStreamChangeCache", device_list_max,
227187 )
228188
189 events_max = self._stream_id_gen.get_current_token()
229190 curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
230191 db_conn, "current_state_delta_stream",
231192 entity_column="room_id",
250211 prefilled_cache=_group_updates_prefill,
251212 )
252213
253 cur = LoggingTransaction(
254 db_conn.cursor(),
255 name="_find_stream_orderings_for_times_txn",
256 database_engine=self.database_engine,
257 after_callbacks=[],
258 final_callbacks=[],
259 )
260 self._find_stream_orderings_for_times_txn(cur)
261 cur.close()
262
263 self.find_stream_orderings_looping_call = self._clock.looping_call(
264 self._find_stream_orderings_for_times, 10 * 60 * 1000
265 )
266
267214 self._stream_order_on_start = self.get_room_max_stream_ordering()
268215 self._min_stream_order_on_start = self.get_room_min_stream_ordering()
269216
270 super(DataStore, self).__init__(hs)
217 super(DataStore, self).__init__(db_conn, hs)
271218
272219 def take_presence_startup_info(self):
273220 active_on_startup = self._presence_on_startup
1515
1616 from synapse.api.errors import StoreError
1717 from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
18 from synapse.util.caches import CACHE_SIZE_FACTOR
19 from synapse.util.caches.dictionary_cache import DictionaryCache
2018 from synapse.util.caches.descriptors import Cache
2119 from synapse.storage.engines import PostgresEngine
2220 import synapse.metrics
4947 passed to the constructor. Adds logging and metrics to the .execute()
5048 method."""
5149 __slots__ = [
52 "txn", "name", "database_engine", "after_callbacks", "final_callbacks",
50 "txn", "name", "database_engine", "after_callbacks", "exception_callbacks",
5351 ]
5452
5553 def __init__(self, txn, name, database_engine, after_callbacks,
56 final_callbacks):
54 exception_callbacks):
5755 object.__setattr__(self, "txn", txn)
5856 object.__setattr__(self, "name", name)
5957 object.__setattr__(self, "database_engine", database_engine)
6058 object.__setattr__(self, "after_callbacks", after_callbacks)
61 object.__setattr__(self, "final_callbacks", final_callbacks)
59 object.__setattr__(self, "exception_callbacks", exception_callbacks)
6260
6361 def call_after(self, callback, *args, **kwargs):
6462 """Call the given callback on the main twisted thread after the
6765 """
6866 self.after_callbacks.append((callback, args, kwargs))
6967
70 def call_finally(self, callback, *args, **kwargs):
71 self.final_callbacks.append((callback, args, kwargs))
68 def call_on_exception(self, callback, *args, **kwargs):
69 self.exception_callbacks.append((callback, args, kwargs))
7270
7371 def __getattr__(self, name):
7472 return getattr(self.txn, name)
102100 "[SQL values] {%s} %r",
103101 self.name, args[0]
104102 )
105 except:
103 except Exception:
106104 # Don't let logging failures stop SQL from working
107105 pass
108106
161159 class SQLBaseStore(object):
162160 _TXN_ID = 0
163161
164 def __init__(self, hs):
162 def __init__(self, db_conn, hs):
165163 self.hs = hs
166164 self._clock = hs.get_clock()
167165 self._db_pool = hs.get_db_pool()
179177 self._get_event_cache = Cache("*getEvent*", keylen=3,
180178 max_entries=hs.config.event_cache_size)
181179
182 self._state_group_cache = DictionaryCache(
183 "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
184 )
185
186180 self._event_fetch_lock = threading.Condition()
187181 self._event_fetch_list = []
188182 self._event_fetch_ongoing = 0
220214
221215 self._clock.looping_call(loop, 10000)
222216
223 def _new_transaction(self, conn, desc, after_callbacks, final_callbacks,
217 def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
224218 logging_context, func, *args, **kwargs):
225219 start = time.time() * 1000
226220 txn_id = self._TXN_ID
241235 txn = conn.cursor()
242236 txn = LoggingTransaction(
243237 txn, name, self.database_engine, after_callbacks,
244 final_callbacks,
238 exception_callbacks,
245239 )
246240 r = func(txn, *args, **kwargs)
247241 conn.commit()
296290
297291 @defer.inlineCallbacks
298292 def runInteraction(self, desc, func, *args, **kwargs):
299 """Wraps the .runInteraction() method on the underlying db_pool."""
293 """Starts a transaction on the database and runs a given function
294
295 Arguments:
296 desc (str): description of the transaction, for logging and metrics
297 func (func): callback function, which will be called with a
298 database transaction (twisted.enterprise.adbapi.Transaction) as
299 its first argument, followed by `args` and `kwargs`.
300
301 args (list): positional args to pass to `func`
302 kwargs (dict): named args to pass to `func`
303
304 Returns:
305 Deferred: The result of func
306 """
300307 current_context = LoggingContext.current_context()
301308
302 start_time = time.time() * 1000
303
304309 after_callbacks = []
305 final_callbacks = []
310 exception_callbacks = []
306311
307312 def inner_func(conn, *args, **kwargs):
308 with LoggingContext("runInteraction") as context:
309 sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
310
311 if self.database_engine.is_connection_closed(conn):
312 logger.debug("Reconnecting closed database connection")
313 conn.reconnect()
314
315 current_context.copy_to(context)
316 return self._new_transaction(
317 conn, desc, after_callbacks, final_callbacks, current_context,
318 func, *args, **kwargs
319 )
313 return self._new_transaction(
314 conn, desc, after_callbacks, exception_callbacks, current_context,
315 func, *args, **kwargs
316 )
320317
321318 try:
322 with PreserveLoggingContext():
323 result = yield self._db_pool.runWithConnection(
324 inner_func, *args, **kwargs
325 )
319 result = yield self.runWithConnection(inner_func, *args, **kwargs)
326320
327321 for after_callback, after_args, after_kwargs in after_callbacks:
328322 after_callback(*after_args, **after_kwargs)
329 finally:
330 for after_callback, after_args, after_kwargs in final_callbacks:
323 except: # noqa: E722, as we reraise the exception this is fine.
324 for after_callback, after_args, after_kwargs in exception_callbacks:
331325 after_callback(*after_args, **after_kwargs)
326 raise
332327
333328 defer.returnValue(result)
334329
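A minimal, hypothetical caller of ``runInteraction`` (names invented for illustration) passes a plain function that receives the transaction as its first argument and runs entirely on the database thread:

    def _count_users_txn(txn):
        # txn behaves like a DB-API cursor (wrapped in LoggingTransaction)
        txn.execute("SELECT COUNT(*) FROM users")
        (count,) = txn.fetchone()
        return count

    @defer.inlineCallbacks
    def count_users(self):
        count = yield self.runInteraction("count_users", _count_users_txn)
        defer.returnValue(count)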
335330 @defer.inlineCallbacks
336331 def runWithConnection(self, func, *args, **kwargs):
337 """Wraps the .runInteraction() method on the underlying db_pool."""
332 """Wraps the .runWithConnection() method on the underlying db_pool.
333
334 Arguments:
335 func (func): callback function, which will be called with a
336 database connection (twisted.enterprise.adbapi.Connection) as
337 its first argument, followed by `args` and `kwargs`.
338 args (list): positional args to pass to `func`
339 kwargs (dict): named args to pass to `func`
340
341 Returns:
342 Deferred: The result of func
343 """
338344 current_context = LoggingContext.current_context()
339345
340346 start_time = time.time() * 1000
341347
342348 def inner_func(conn, *args, **kwargs):
343349 with LoggingContext("runWithConnection") as context:
344 sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
350 sched_duration_ms = time.time() * 1000 - start_time
351 sql_scheduling_timer.inc_by(sched_duration_ms)
352 current_context.add_database_scheduled(sched_duration_ms)
345353
346354 if self.database_engine.is_connection_closed(conn):
347355 logger.debug("Reconnecting closed database connection")
474482
475483 txn.executemany(sql, vals)
476484
485 @defer.inlineCallbacks
477486 def _simple_upsert(self, table, keyvalues, values,
478487 insertion_values={}, desc="_simple_upsert", lock=True):
479488 """
489
490 `lock` should generally be set to True (the default), but can be set
491 to False if either of the following are true:
492
493 * there is a UNIQUE INDEX on the key columns. In this case a conflict
494 will cause an IntegrityError in which case this function will retry
495 the update.
496
497 * we somehow know that we are the only thread which will be updating
498 this table.
499
480500 Args:
481501 table (str): The table to upsert into
482502 keyvalues (dict): The unique key tables and their new values
483503 values (dict): The nonunique columns and their new values
484 insertion_values (dict): key/values to use when inserting
504 insertion_values (dict): additional key/values to use only when
505 inserting
506 lock (bool): True to lock the table when doing the upsert.
485507 Returns:
486508 Deferred(bool): True if a new entry was created, False if an
487509 existing one was updated.
488510 """
489 return self.runInteraction(
490 desc,
491 self._simple_upsert_txn, table, keyvalues, values, insertion_values,
492 lock
493 )
511 attempts = 0
512 while True:
513 try:
514 result = yield self.runInteraction(
515 desc,
516 self._simple_upsert_txn, table, keyvalues, values, insertion_values,
517 lock=lock
518 )
519 defer.returnValue(result)
520 except self.database_engine.module.IntegrityError as e:
521 attempts += 1
522 if attempts >= 5:
523 # don't retry forever, because things other than races
524 # can cause IntegrityErrors
525 raise
526
527 # presumably we raced with another transaction: let's retry.
528 logger.warn(
529 "IntegrityError when upserting into %s; retrying: %s",
530 table, e
531 )
494532
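As a rough usage sketch (table and column names are invented for illustration): when the key columns are covered by a UNIQUE index, a caller can pass ``lock=False`` and rely on the IntegrityError retry loop above; otherwise the default ``lock=True`` serialises concurrent upserts by locking the table:

    yield self._simple_upsert(
        table="device_last_seen",  # hypothetical table with a UNIQUE (user_id, device_id) index
        keyvalues={"user_id": user_id, "device_id": device_id},
        values={"last_seen_ts": now_ms},
        desc="update_device_last_seen",
        lock=False,
    )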
495533 def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
496534 lock=True):
498536 if lock:
499537 self.database_engine.lock_table(txn, table)
500538
501 # Try to update
539 # First try to update.
502540 sql = "UPDATE %s SET %s WHERE %s" % (
503541 table,
504542 ", ".join("%s = ?" % (k,) for k in values),
507545 sqlargs = values.values() + keyvalues.values()
508546
509547 txn.execute(sql, sqlargs)
510 if txn.rowcount == 0:
511 # We didn't update any rows so insert a new one
512 allvalues = {}
513 allvalues.update(keyvalues)
514 allvalues.update(values)
515 allvalues.update(insertion_values)
516
517 sql = "INSERT INTO %s (%s) VALUES (%s)" % (
518 table,
519 ", ".join(k for k in allvalues),
520 ", ".join("?" for _ in allvalues)
521 )
522 txn.execute(sql, allvalues.values())
523
524 return True
525 else:
548 if txn.rowcount > 0:
549 # successfully updated at least one row.
526550 return False
551
552 # We didn't update any rows so insert a new one
553 allvalues = {}
554 allvalues.update(keyvalues)
555 allvalues.update(values)
556 allvalues.update(insertion_values)
557
558 sql = "INSERT INTO %s (%s) VALUES (%s)" % (
559 table,
560 ", ".join(k for k in allvalues),
561 ", ".join("?" for _ in allvalues)
562 )
563 txn.execute(sql, allvalues.values())
564 # successfully inserted
565 return True
527566
528567 def _simple_select_one(self, table, keyvalues, retcols,
529568 allow_none=False, desc="_simple_select_one"):
530569 """Executes a SELECT query on the named table, which is expected to
531 return a single row, returning a single column from it.
570 return a single row, returning multiple columns from it.
532571
533572 Args:
534573 table : string giving the table name
581620
582621 @staticmethod
583622 def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
584 if keyvalues:
585 where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
586 else:
587 where = ""
588
589623 sql = (
590 "SELECT %(retcol)s FROM %(table)s %(where)s"
624 "SELECT %(retcol)s FROM %(table)s"
591625 ) % {
592626 "retcol": retcol,
593627 "table": table,
594 "where": where,
595628 }
596629
597 txn.execute(sql, keyvalues.values())
630 if keyvalues:
631 sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
632 txn.execute(sql, keyvalues.values())
633 else:
634 txn.execute(sql)
598635
599636 return [r[0] for r in txn]
600637
605642
606643 Args:
607644 table (str): table name
608 keyvalues (dict): column names and values to select the rows with
645 keyvalues (dict|None): column names and values to select the rows with
609646 retcol (str): column whose value we wish to retrieve.
610647
611648 Returns:
9631000 # __exit__ called after the transaction finishes.
9641001 ctx = self._cache_id_gen.get_next()
9651002 stream_id = ctx.__enter__()
966 txn.call_finally(ctx.__exit__, None, None, None)
1003 txn.call_on_exception(ctx.__exit__, None, None, None)
1004 txn.call_after(ctx.__exit__, None, None, None)
9671005 txn.call_after(self.hs.get_notifier().on_new_replication_data)
9681006
9691007 self._simple_insert_txn(
00 # -*- coding: utf-8 -*-
11 # Copyright 2014-2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1213 # See the License for the specific language governing permissions and
1314 # limitations under the License.
1415
15 from ._base import SQLBaseStore
1616 from twisted.internet import defer
1717
18 from synapse.storage._base import SQLBaseStore
19 from synapse.storage.util.id_generators import StreamIdGenerator
20
21 from synapse.util.caches.stream_change_cache import StreamChangeCache
1822 from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
1923
20 import ujson as json
24 import abc
25 import simplejson as json
2126 import logging
2227
2328 logger = logging.getLogger(__name__)
2429
2530
26 class AccountDataStore(SQLBaseStore):
31 class AccountDataWorkerStore(SQLBaseStore):
32 """This is an abstract base class where subclasses must implement
33 `get_max_account_data_stream_id` which can be called in the initializer.
34 """
35
36 # This ABCMeta metaclass ensures that we cannot be instantiated without
37 # the abstract methods being implemented.
38 __metaclass__ = abc.ABCMeta
39
40 def __init__(self, db_conn, hs):
41 account_max = self.get_max_account_data_stream_id()
42 self._account_data_stream_cache = StreamChangeCache(
43 "AccountDataAndTagsChangeCache", account_max,
44 )
45
46 super(AccountDataWorkerStore, self).__init__(db_conn, hs)
47
48 @abc.abstractmethod
49 def get_max_account_data_stream_id(self):
50 """Get the current max stream ID for account data stream
51
52 Returns:
53 int
54 """
55 raise NotImplementedError()
2756
2857 @cached()
2958 def get_account_data_for_user(self, user_id):
6291 "get_account_data_for_user", get_account_data_for_user_txn
6392 )
6493
65 @cachedInlineCallbacks(num_args=2)
94 @cachedInlineCallbacks(num_args=2, max_entries=5000)
6695 def get_global_account_data_by_type_for_user(self, data_type, user_id):
6796 """
6897 Returns:
103132 for row in rows
104133 })
105134
135 @cached(num_args=2)
106136 def get_account_data_for_room(self, user_id, room_id):
107137 """Get all the client account_data for a user for a room.
108138
124154
125155 return self.runInteraction(
126156 "get_account_data_for_room", get_account_data_for_room_txn
157 )
158
159 @cached(num_args=3, max_entries=5000)
160 def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
161 """Get the client account_data of given type for a user for a room.
162
163 Args:
164 user_id(str): The user to get the account_data for.
165 room_id(str): The room to get the account_data for.
166 account_data_type (str): The account data type to get.
167 Returns:
168 A deferred of the room account_data for that type, or None if
169 there isn't any set.
170 """
171 def get_account_data_for_room_and_type_txn(txn):
172 content_json = self._simple_select_one_onecol_txn(
173 txn,
174 table="room_account_data",
175 keyvalues={
176 "user_id": user_id,
177 "room_id": room_id,
178 "account_data_type": account_data_type,
179 },
180 retcol="content",
181 allow_none=True
182 )
183
184 return json.loads(content_json) if content_json else None
185
186 return self.runInteraction(
187 "get_account_data_for_room_and_type",
188 get_account_data_for_room_and_type_txn,
127189 )
128190
129191 def get_all_updated_account_data(self, last_global_id, last_room_id,
208270 "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
209271 )
210272
273 @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
274 def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
275 ignored_account_data = yield self.get_global_account_data_by_type_for_user(
276 "m.ignored_user_list", ignorer_user_id,
277 on_invalidate=cache_context.invalidate,
278 )
279 if not ignored_account_data:
280 defer.returnValue(False)
281
282 defer.returnValue(
283 ignored_user_id in ignored_account_data.get("ignored_users", {})
284 )
285
286
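A hypothetical illustration of the data this reads: the ``m.ignored_user_list`` account data entry keeps the ignored users under an ``ignored_users`` key, so the check above is just a dict membership test:

    # hypothetical account data for @alice:example.com:
    #   {"ignored_users": {"@spammer:example.com": {}}}
    ignored = yield self.is_ignored_by("@spammer:example.com", "@alice:example.com")
    # True; any user id absent from "ignored_users" (or missing account data) gives False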
287 class AccountDataStore(AccountDataWorkerStore):
288 def __init__(self, db_conn, hs):
289 self._account_data_id_gen = StreamIdGenerator(
290 db_conn, "account_data_max_stream_id", "stream_id"
291 )
292
293 super(AccountDataStore, self).__init__(db_conn, hs)
294
295 def get_max_account_data_stream_id(self):
296 """Get the current max stream id for the private user data stream
297
298 Returns:
299 A deferred int.
300 """
301 return self._account_data_id_gen.get_current_token()
302
211303 @defer.inlineCallbacks
212304 def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
213305 """Add some account_data to a room for a user.
221313 """
222314 content_json = json.dumps(content)
223315
224 def add_account_data_txn(txn, next_id):
225 self._simple_upsert_txn(
226 txn,
316 with self._account_data_id_gen.get_next() as next_id:
317 # no need to lock here as room_account_data has a unique constraint
318 # on (user_id, room_id, account_data_type) so _simple_upsert will
319 # retry if there is a conflict.
320 yield self._simple_upsert(
321 desc="add_room_account_data",
227322 table="room_account_data",
228323 keyvalues={
229324 "user_id": user_id,
233328 values={
234329 "stream_id": next_id,
235330 "content": content_json,
236 }
237 )
238 txn.call_after(
239 self._account_data_stream_cache.entity_has_changed,
240 user_id, next_id,
241 )
242 txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
243 self._update_max_stream_id(txn, next_id)
244
245 with self._account_data_id_gen.get_next() as next_id:
246 yield self.runInteraction(
247 "add_room_account_data", add_account_data_txn, next_id
331 },
332 lock=False,
333 )
334
335 # it's theoretically possible for the above to succeed and the
336 # below to fail - in which case we might reuse a stream id on
337 # restart, and the above update might not get propagated. That
338 # doesn't sound any worse than the whole update getting lost,
339 # which is what would happen if we combined the two into one
340 # transaction.
341 yield self._update_max_stream_id(next_id)
342
343 self._account_data_stream_cache.entity_has_changed(user_id, next_id)
344 self.get_account_data_for_user.invalidate((user_id,))
345 self.get_account_data_for_room.invalidate((user_id, room_id,))
346 self.get_account_data_for_room_and_type.prefill(
347 (user_id, room_id, account_data_type,), content,
248348 )
249349
250350 result = self._account_data_id_gen.get_current_token()
262362 """
263363 content_json = json.dumps(content)
264364
265 def add_account_data_txn(txn, next_id):
266 self._simple_upsert_txn(
267 txn,
365 with self._account_data_id_gen.get_next() as next_id:
366 # no need to lock here as account_data has a unique constraint on
367 # (user_id, account_data_type) so _simple_upsert will retry if
368 # there is a conflict.
369 yield self._simple_upsert(
370 desc="add_user_account_data",
268371 table="account_data",
269372 keyvalues={
270373 "user_id": user_id,
273376 values={
274377 "stream_id": next_id,
275378 "content": content_json,
276 }
277 )
278 txn.call_after(
279 self._account_data_stream_cache.entity_has_changed,
379 },
380 lock=False,
381 )
382
383 # it's theoretically possible for the above to succeed and the
384 # below to fail - in which case we might reuse a stream id on
385 # restart, and the above update might not get propagated. That
386 # doesn't sound any worse than the whole update getting lost,
387 # which is what would happen if we combined the two into one
388 # transaction.
389 yield self._update_max_stream_id(next_id)
390
391 self._account_data_stream_cache.entity_has_changed(
280392 user_id, next_id,
281393 )
282 txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
283 txn.call_after(
284 self.get_global_account_data_by_type_for_user.invalidate,
394 self.get_account_data_for_user.invalidate((user_id,))
395 self.get_global_account_data_by_type_for_user.invalidate(
285396 (account_data_type, user_id,)
286 )
287 self._update_max_stream_id(txn, next_id)
288
289 with self._account_data_id_gen.get_next() as next_id:
290 yield self.runInteraction(
291 "add_user_account_data", add_account_data_txn, next_id
292397 )
293398
294399 result = self._account_data_id_gen.get_current_token()
295400 defer.returnValue(result)
296401
297 def _update_max_stream_id(self, txn, next_id):
402 def _update_max_stream_id(self, next_id):
298403 """Update the max stream_id
299404
300405 Args:
301 txn: The database cursor
302406 next_id(int): The revision to advance to.
303407 """
304 update_max_id_sql = (
305 "UPDATE account_data_max_stream_id"
306 " SET stream_id = ?"
307 " WHERE stream_id < ?"
308 )
309 txn.execute(update_max_id_sql, (next_id, next_id))
310
311 @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
312 def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
313 ignored_account_data = yield self.get_global_account_data_by_type_for_user(
314 "m.ignored_user_list", ignorer_user_id,
315 on_invalidate=cache_context.invalidate,
316 )
317 if not ignored_account_data:
318 defer.returnValue(False)
319
320 defer.returnValue(
321 ignored_user_id in ignored_account_data.get("ignored_users", {})
322 )
408 def _update(txn):
409 update_max_id_sql = (
410 "UPDATE account_data_max_stream_id"
411 " SET stream_id = ?"
412 " WHERE stream_id < ?"
413 )
414 txn.execute(update_max_id_sql, (next_id, next_id))
415 return self.runInteraction(
416 "update_account_data_max_stream_id",
417 _update,
418 )
00 # -*- coding: utf-8 -*-
11 # Copyright 2015, 2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1617 import simplejson as json
1718 from twisted.internet import defer
1819
19 from synapse.api.constants import Membership
2020 from synapse.appservice import AppServiceTransaction
2121 from synapse.config.appservice import load_appservices
22 from synapse.storage.roommember import RoomsForUser
22 from synapse.storage.events import EventsWorkerStore
2323 from ._base import SQLBaseStore
2424
2525
4545 return exclusive_user_regex
4646
4747
48 class ApplicationServiceStore(SQLBaseStore):
49
50 def __init__(self, hs):
51 super(ApplicationServiceStore, self).__init__(hs)
52 self.hostname = hs.hostname
48 class ApplicationServiceWorkerStore(SQLBaseStore):
49 def __init__(self, db_conn, hs):
5350 self.services_cache = load_appservices(
5451 hs.hostname,
5552 hs.config.app_service_config_files
5653 )
5754 self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
55
56 super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs)
5857
5958 def get_app_services(self):
6059 return self.services_cache
9897 return service
9998 return None
10099
101 def get_app_service_rooms(self, service):
102 """Get a list of RoomsForUser for this application service.
103
104 Application services may be "interested" in lots of rooms depending on
105 the room ID, the room aliases, or the members in the room. This function
106 takes all of these into account and returns a list of RoomsForUser which
107 represent the entire list of room IDs that this application service
108 wants to know about.
109
110 Args:
111 service: The application service to get a room list for.
112 Returns:
113 A list of RoomsForUser.
114 """
115 return self.runInteraction(
116 "get_app_service_rooms",
117 self._get_app_service_rooms_txn,
118 service,
119 )
120
121 def _get_app_service_rooms_txn(self, txn, service):
122 # get all rooms matching the room ID regex.
123 room_entries = self._simple_select_list_txn(
124 txn=txn, table="rooms", keyvalues=None, retcols=["room_id"]
125 )
126 matching_room_list = set([
127 r["room_id"] for r in room_entries if
128 service.is_interested_in_room(r["room_id"])
129 ])
130
131 # resolve room IDs for matching room alias regex.
132 room_alias_mappings = self._simple_select_list_txn(
133 txn=txn, table="room_aliases", keyvalues=None,
134 retcols=["room_id", "room_alias"]
135 )
136 matching_room_list |= set([
137 r["room_id"] for r in room_alias_mappings if
138 service.is_interested_in_alias(r["room_alias"])
139 ])
140
141 # get all rooms for every user for this AS. This is scoped to users on
142 # this HS only.
143 user_list = self._simple_select_list_txn(
144 txn=txn, table="users", keyvalues=None, retcols=["name"]
145 )
146 user_list = [
147 u["name"] for u in user_list if
148 service.is_interested_in_user(u["name"])
149 ]
150 rooms_for_user_matching_user_id = set() # RoomsForUser list
151 for user_id in user_list:
152 # FIXME: This assumes this store is linked with RoomMemberStore :(
153 rooms_for_user = self._get_rooms_for_user_where_membership_is_txn(
154 txn=txn,
155 user_id=user_id,
156 membership_list=[Membership.JOIN]
157 )
158 rooms_for_user_matching_user_id |= set(rooms_for_user)
159
160 # make RoomsForUser tuples for room ids and aliases which are not in the
161 # main rooms_for_user_list - e.g. they are rooms which do not have AS
162 # registered users in it.
163 known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id]
164 missing_rooms_for_user = [
165 RoomsForUser(r, service.sender, "join") for r in
166 matching_room_list if r not in known_room_ids
167 ]
168 rooms_for_user_matching_user_id |= set(missing_rooms_for_user)
169
170 return rooms_for_user_matching_user_id
171
172
173 class ApplicationServiceTransactionStore(SQLBaseStore):
174
175 def __init__(self, hs):
176 super(ApplicationServiceTransactionStore, self).__init__(hs)
177
100 def get_app_service_by_id(self, as_id):
101 """Get the application service with the given appservice ID.
102
103 Args:
104 as_id (str): The application service ID.
105 Returns:
106 synapse.appservice.ApplicationService or None.
107 """
108 for service in self.services_cache:
109 if service.id == as_id:
110 return service
111 return None
112
113
114 class ApplicationServiceStore(ApplicationServiceWorkerStore):
115 # This is currently empty due to there not being any AS storage functions
116 # that can't be run on the workers. Since this may change in future, and
117 # to keep consistency with the other stores, we keep this empty class for
118 # now.
119 pass
120
121
122 class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
123 EventsWorkerStore):
178124 @defer.inlineCallbacks
179125 def get_appservices_by_state(self, state):
180126 """Get a list of application services based on their state.
419365 events = yield self._get_events(event_ids)
420366
421367 defer.returnValue((upper_bound, events))
368
369
370 class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
371 # This is currently empty due to there not being any AS storage functions
372 # that can't be run on the workers. Since this may change in future, and
373 # to keep consistency with the other stores, we keep this empty class for
374 # now.
375 pass
1818
1919 from twisted.internet import defer
2020
21 import ujson as json
21 import simplejson as json
2222 import logging
2323
2424 logger = logging.getLogger(__name__)
7979 BACKGROUND_UPDATE_INTERVAL_MS = 1000
8080 BACKGROUND_UPDATE_DURATION_MS = 100
8181
82 def __init__(self, hs):
83 super(BackgroundUpdateStore, self).__init__(hs)
82 def __init__(self, db_conn, hs):
83 super(BackgroundUpdateStore, self).__init__(db_conn, hs)
8484 self._background_update_performance = {}
8585 self._background_update_queue = []
8686 self._background_update_handlers = {}
87 self._all_done = False
8788
8889 @defer.inlineCallbacks
8990 def start_doing_background_updates(self):
9798 result = yield self.do_next_background_update(
9899 self.BACKGROUND_UPDATE_DURATION_MS
99100 )
100 except:
101 except Exception:
101102 logger.exception("Error doing update")
102103 else:
103104 if result is None:
105106 "No more background updates to do."
106107 " Unscheduling background update task."
107108 )
109 self._all_done = True
108110 defer.returnValue(None)
111
112 @defer.inlineCallbacks
113 def has_completed_background_updates(self):
114 """Check if all the background updates have completed
115
116 Returns:
117 Deferred[bool]: True if all background updates have completed
118 """
119 # if we've previously determined that there is nothing left to do, that
120 # is easy
121 if self._all_done:
122 defer.returnValue(True)
123
124 # obviously, if we have things in our queue, we're not done.
125 if self._background_update_queue:
126 defer.returnValue(False)
127
128 # otherwise, check if there are updates to be run. This is important,
129 # as we may be running on a worker which doesn't perform the bg updates
130 # itself, but still wants to wait for them to happen.
131 updates = yield self._simple_select_onecol(
132 "background_updates",
133 keyvalues=None,
134 retcol="1",
135 desc="check_background_updates",
136 )
137 if not updates:
138 self._all_done = True
139 defer.returnValue(True)
140
141 defer.returnValue(False)
109142
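A hypothetical caller (for example a worker start-up path or a test helper; the method name below is invented) could combine this with ``do_next_background_update`` to drain any pending updates:

    @defer.inlineCallbacks
    def wait_for_background_updates(self):
        # keep running updates until none are pending
        while not (yield self.has_completed_background_updates()):
            yield self.do_next_background_update(self.BACKGROUND_UPDATE_DURATION_MS)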
110143 @defer.inlineCallbacks
111144 def do_next_background_update(self, desired_duration_ms):
207240 update_handler(function): The function that does the update.
208241 """
209242 self._background_update_handlers[update_name] = update_handler
243
244 def register_noop_background_update(self, update_name):
245 """Register a noop handler for a background update.
246
247 This is useful when we previously did a background update, but no
248 longer wish to do the update. In this case the background update should
249 be removed from the schema delta files, but there may still be some
250 users who have the background update queued, so this method should
251 also be called to clear the update.
252
253 Args:
254 update_name (str): Name of update
255 """
256 @defer.inlineCallbacks
257 def noop_update(progress, batch_size):
258 yield self._end_background_update(update_name)
259 defer.returnValue(1)
260
261 self.register_background_update_handler(update_name, noop_update)
210262
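A short usage sketch (the update name is invented for illustration): a store that has dropped an old background update from its schema deltas would register a no-op handler so any still-queued rows get cleared:

    # hypothetical: "old_event_reindex" is no longer needed, but some
    # databases may still have it queued in the background_updates table
    self.register_noop_background_update("old_event_reindex")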
211263 def register_background_index_update(self, update_name, index_name,
212264 table, columns, where_clause=None,
268320 # Sqlite doesn't support concurrent creation of indexes.
269321 #
270322 # We don't use partial indices on SQLite as it wasn't introduced
271 # until 3.8, and wheezy has 3.7
323 # until 3.8, and wheezy and CentOS 7 have 3.7
272324 #
273325 # We assume that sqlite doesn't give us invalid indices; however
274326 # we may still end up with the index existing but the
3131
3232
3333 class ClientIpStore(background_updates.BackgroundUpdateStore):
34 def __init__(self, hs):
34 def __init__(self, db_conn, hs):
3535 self.client_ip_last_seen = Cache(
3636 name="client_ip_last_seen",
3737 keylen=4,
3838 max_entries=50000 * CACHE_SIZE_FACTOR,
3939 )
4040
41 super(ClientIpStore, self).__init__(hs)
41 super(ClientIpStore, self).__init__(db_conn, hs)
4242
4343 self.register_background_index_update(
4444 "user_ips_device_index",
1313 # limitations under the License.
1414
1515 import logging
16 import ujson
16 import simplejson
1717
1818 from twisted.internet import defer
1919
2828 class DeviceInboxStore(BackgroundUpdateStore):
2929 DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
3030
31 def __init__(self, hs):
32 super(DeviceInboxStore, self).__init__(hs)
31 def __init__(self, db_conn, hs):
32 super(DeviceInboxStore, self).__init__(db_conn, hs)
3333
3434 self.register_background_index_update(
3535 "device_inbox_stream_index",
8484 )
8585 rows = []
8686 for destination, edu in remote_messages_by_destination.items():
87 edu_json = ujson.dumps(edu)
87 edu_json = simplejson.dumps(edu)
8888 rows.append((destination, stream_id, now_ms, edu_json))
8989 txn.executemany(sql, rows)
9090
176176 " WHERE user_id = ?"
177177 )
178178 txn.execute(sql, (user_id,))
179 message_json = ujson.dumps(messages_by_device["*"])
179 message_json = simplejson.dumps(messages_by_device["*"])
180180 for row in txn:
181181 # Add the message for all devices for this user on this
182182 # server.
198198 # Only insert into the local inbox if the device exists on
199199 # this server
200200 device = row[0]
201 message_json = ujson.dumps(messages_by_device[device])
201 message_json = simplejson.dumps(messages_by_device[device])
202202 messages_json_for_user[device] = message_json
203203
204204 if messages_json_for_user:
252252 messages = []
253253 for row in txn:
254254 stream_pos = row[0]
255 messages.append(ujson.loads(row[1]))
255 messages.append(simplejson.loads(row[1]))
256256 if len(messages) < limit:
257257 stream_pos = current_stream_id
258258 return (messages, stream_pos)
388388 messages = []
389389 for row in txn:
390390 stream_pos = row[0]
391 messages.append(ujson.loads(row[1]))
391 messages.append(simplejson.loads(row[1]))
392392 if len(messages) < limit:
393393 stream_pos = current_stream_id
394394 return (messages, stream_pos)
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414 import logging
15 import ujson as json
15 import simplejson as json
1616
1717 from twisted.internet import defer
1818
2525
2626
2727 class DeviceStore(SQLBaseStore):
28 def __init__(self, hs):
29 super(DeviceStore, self).__init__(hs)
28 def __init__(self, db_conn, hs):
29 super(DeviceStore, self).__init__(db_conn, hs)
3030
3131 # Map of (user_id, device_id) -> bool. If there is an entry that implies
3232 # the device exists.
2828 )
2929
3030
31 class DirectoryStore(SQLBaseStore):
32
31 class DirectoryWorkerStore(SQLBaseStore):
3332 @defer.inlineCallbacks
3433 def get_association_from_room_alias(self, room_alias):
3534 """ Gets the room_id and server list for a given room_alias
6867 RoomAliasMapping(room_id, room_alias.to_string(), servers)
6968 )
7069
70 def get_room_alias_creator(self, room_alias):
71 return self._simple_select_one_onecol(
72 table="room_aliases",
73 keyvalues={
74 "room_alias": room_alias,
75 },
76 retcol="creator",
77 desc="get_room_alias_creator",
78 allow_none=True
79 )
80
81 @cached(max_entries=5000)
82 def get_aliases_for_room(self, room_id):
83 return self._simple_select_onecol(
84 "room_aliases",
85 {"room_id": room_id},
86 "room_alias",
87 desc="get_aliases_for_room",
88 )
89
90
91 class DirectoryStore(DirectoryWorkerStore):
7192 @defer.inlineCallbacks
7293 def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
7394 """ Creates an associatin between a room alias and room_id/servers
115136 )
116137 defer.returnValue(ret)
117138
118 def get_room_alias_creator(self, room_alias):
119 return self._simple_select_one_onecol(
120 table="room_aliases",
121 keyvalues={
122 "room_alias": room_alias,
123 },
124 retcol="creator",
125 desc="get_room_alias_creator",
126 allow_none=True
127 )
128
129139 @defer.inlineCallbacks
130140 def delete_room_alias(self, room_alias):
131141 room_id = yield self.runInteraction(
134144 room_alias,
135145 )
136146
137 self.get_aliases_for_room.invalidate((room_id,))
138147 defer.returnValue(room_id)
139148
140149 def _delete_room_alias_txn(self, txn, room_alias):
159168 (room_alias.to_string(),)
160169 )
161170
171 self._invalidate_cache_and_stream(
172 txn, self.get_aliases_for_room, (room_id,)
173 )
174
162175 return room_id
163
164 @cached(max_entries=5000)
165 def get_aliases_for_room(self, room_id):
166 return self._simple_select_onecol(
167 "room_aliases",
168 {"room_id": room_id},
169 "room_alias",
170 desc="get_aliases_for_room",
171 )
172176
173177 def update_aliases_for_room(self, old_room_id, new_room_id, creator):
174178 def _update_aliases_for_room_txn(txn):
1616 from synapse.util.caches.descriptors import cached
1717
1818 from canonicaljson import encode_canonical_json
19 import ujson as json
19 import simplejson as json
2020
2121 from ._base import SQLBaseStore
2222
6161
6262 def lock_table(self, txn, table):
6363 txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
64
65 def get_next_state_group_id(self, txn):
66 """Returns an int that can be used as a new state_group ID
67 """
68 txn.execute("SELECT nextval('state_group_id_seq')")
69 return txn.fetchone()[0]
1515 from synapse.storage.prepare_database import prepare_database
1616
1717 import struct
18 import threading
1819
1920
2021 class Sqlite3Engine(object):
2223
2324 def __init__(self, database_module, database_config):
2425 self.module = database_module
26
27 # The current max state_group, or None if we haven't looked
28 # in the DB yet.
29 self._current_state_group_id = None
30 self._current_state_group_id_lock = threading.Lock()
2531
2632 def check_database(self, txn):
2733 pass
4147
4248 def lock_table(self, txn, table):
4349 return
50
51 def get_next_state_group_id(self, txn):
52 """Returns an int that can be used as a new state_group ID
53 """
54 # We do application locking here since if we're using sqlite then
55 # we are a single process synapse.
56 with self._current_state_group_id_lock:
57 if self._current_state_group_id is None:
58 txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
59 self._current_state_group_id = txn.fetchone()[0]
60
61 self._current_state_group_id += 1
62 return self._current_state_group_id
4463
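For comparison, a sketch of how a caller might obtain a new state group id from either engine (the call site is hypothetical): on PostgreSQL this draws from the ``state_group_id_seq`` sequence, while on SQLite it increments the in-memory counter guarded by the lock above:

    def _store_state_group_txn(self, txn, event_id, room_id):
        # illustrative only: ask the engine for the next id, then insert
        # the new row into state_groups within the same transaction
        state_group_id = self.database_engine.get_next_state_group_id(txn)
        return state_group_id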
4564
4665 # Following functions taken from: https://github.com/coleifer/peewee
1414
1515 from twisted.internet import defer
1616
17 from ._base import SQLBaseStore
17 from synapse.storage._base import SQLBaseStore
18 from synapse.storage.events import EventsWorkerStore
19 from synapse.storage.signatures import SignatureWorkerStore
20
1821 from synapse.api.errors import StoreError
1922 from synapse.util.caches.descriptors import cached
2023 from unpaddedbase64 import encode_base64
2629 logger = logging.getLogger(__name__)
2730
2831
29 class EventFederationStore(SQLBaseStore):
30 """ Responsible for storing and serving up the various graphs associated
31 with an event. Including the main event graph and the auth chains for an
32 event.
33
34 Also has methods for getting the front (latest) and back (oldest) edges
35 of the event graphs. These are used to generate the parents for new events
36 and backfilling from another server respectively.
37 """
38
39 EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
40
41 def __init__(self, hs):
42 super(EventFederationStore, self).__init__(hs)
43
44 self.register_background_update_handler(
45 self.EVENT_AUTH_STATE_ONLY,
46 self._background_delete_non_state_event_auth,
47 )
48
49 hs.get_clock().looping_call(
50 self._delete_old_forward_extrem_cache, 60 * 60 * 1000
51 )
52
32 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
33 SQLBaseStore):
5334 def get_auth_chain(self, event_ids, include_given=False):
5435 """Get auth events for given event_ids. The events *must* be state events.
5536
227208
228209 return int(min_depth) if min_depth is not None else None
229210
211 def get_forward_extremeties_for_room(self, room_id, stream_ordering):
212 """For a given room_id and stream_ordering, return the forward
213 extremeties of the room at that point in "time".
214
215 Throws a StoreError if we have since purged the index for
216 stream_orderings from that point.
217
218 Args:
219 room_id (str):
220 stream_ordering (int):
221
222 Returns:
223 deferred, which resolves to a list of event_ids
224 """
225 # We want to make the cache more effective, so we clamp to the last
226 # change before the given ordering.
227 last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)
228
229 # We don't always have a full stream_to_exterm_id table, e.g. after
230 # the upgrade that introduced it, so we make sure we never ask for a
231 # stream_ordering from before a restart
232 last_change = max(self._stream_order_on_start, last_change)
233
234 # provided the last_change is recent enough, we now clamp the requested
235 # stream_ordering to it.
236 if last_change > self.stream_ordering_month_ago:
237 stream_ordering = min(last_change, stream_ordering)
238
239 return self._get_forward_extremeties_for_room(room_id, stream_ordering)
240
241 @cached(max_entries=5000, num_args=2)
242 def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
243 """For a given room_id and stream_ordering, return the forward
244 extremeties of the room at that point in "time".
245
246 Throws a StoreError if we have since purged the index for
247 stream_orderings from that point.
248 """
249
250 if stream_ordering <= self.stream_ordering_month_ago:
251 raise StoreError(400, "stream_ordering too old")
252
253 sql = ("""
254 SELECT event_id FROM stream_ordering_to_exterm
255 INNER JOIN (
256 SELECT room_id, MAX(stream_ordering) AS stream_ordering
257 FROM stream_ordering_to_exterm
258 WHERE stream_ordering <= ? GROUP BY room_id
259 ) AS rms USING (room_id, stream_ordering)
260 WHERE room_id = ?
261 """)
262
263 def get_forward_extremeties_for_room_txn(txn):
264 txn.execute(sql, (stream_ordering, room_id))
265 return [event_id for event_id, in txn]
266
267 return self.runInteraction(
268 "get_forward_extremeties_for_room",
269 get_forward_extremeties_for_room_txn
270 )
271
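A small worked example of the clamping above, with hypothetical stream orderings: if the last forward-extremity change in the room was at 1050, the server restarted at 1000, and a caller asks for the extremities at 1080, the lookup is clamped to 1050, so every query between 1050 and the next change hits the same cache entry:

    last_change = 1050                        # last change for this room
    last_change = max(1000, last_change)      # never before the restart point
    stream_ordering = min(last_change, 1080)  # => 1050 (assuming it is recent enough)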
272 def get_backfill_events(self, room_id, event_list, limit):
273 """Get a list of Events for a given topic that occurred before (and
274 including) the events in event_list. Return a list of max size `limit`
275
276 Args:
277 txn
278 room_id (str)
279 event_list (list)
280 limit (int)
281 """
282 return self.runInteraction(
283 "get_backfill_events",
284 self._get_backfill_events, room_id, event_list, limit
285 ).addCallback(
286 self._get_events
287 ).addCallback(
288 lambda l: sorted(l, key=lambda e: -e.depth)
289 )
290
291 def _get_backfill_events(self, txn, room_id, event_list, limit):
292 logger.debug(
293 "_get_backfill_events: %s, %s, %s",
294 room_id, repr(event_list), limit
295 )
296
297 event_results = set()
298
299 # We want to make sure that we do a breadth-first, "depth" ordered
300 # search.
301
302 query = (
303 "SELECT depth, prev_event_id FROM event_edges"
304 " INNER JOIN events"
305 " ON prev_event_id = events.event_id"
306 " AND event_edges.room_id = events.room_id"
307 " WHERE event_edges.room_id = ? AND event_edges.event_id = ?"
308 " AND event_edges.is_state = ?"
309 " LIMIT ?"
310 )
311
312 queue = PriorityQueue()
313
314 for event_id in event_list:
315 depth = self._simple_select_one_onecol_txn(
316 txn,
317 table="events",
318 keyvalues={
319 "event_id": event_id,
320 },
321 retcol="depth",
322 allow_none=True,
323 )
324
325 if depth:
326 queue.put((-depth, event_id))
327
328 while not queue.empty() and len(event_results) < limit:
329 try:
330 _, event_id = queue.get_nowait()
331 except Empty:
332 break
333
334 if event_id in event_results:
335 continue
336
337 event_results.add(event_id)
338
339 txn.execute(
340 query,
341 (room_id, event_id, False, limit - len(event_results))
342 )
343
344 for row in txn:
345 if row[1] not in event_results:
346 queue.put((-row[0], row[1]))
347
348 return event_results
349
350 @defer.inlineCallbacks
351 def get_missing_events(self, room_id, earliest_events, latest_events,
352 limit, min_depth):
353 ids = yield self.runInteraction(
354 "get_missing_events",
355 self._get_missing_events,
356 room_id, earliest_events, latest_events, limit, min_depth
357 )
358
359 events = yield self._get_events(ids)
360
361 events = sorted(
362 [ev for ev in events if ev.depth >= min_depth],
363 key=lambda e: e.depth,
364 )
365
366 defer.returnValue(events[:limit])
367
368 def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
369 limit, min_depth):
370
371 earliest_events = set(earliest_events)
372 front = set(latest_events) - earliest_events
373
374 event_results = set()
375
376 query = (
377 "SELECT prev_event_id FROM event_edges "
378 "WHERE room_id = ? AND event_id = ? AND is_state = ? "
379 "LIMIT ?"
380 )
381
382 while front and len(event_results) < limit:
383 new_front = set()
384 for event_id in front:
385 txn.execute(
386 query,
387 (room_id, event_id, False, limit - len(event_results))
388 )
389
390 for e_id, in txn:
391 new_front.add(e_id)
392
393 new_front -= earliest_events
394 new_front -= event_results
395
396 front = new_front
397 event_results |= new_front
398
399 return event_results
400
401
402 class EventFederationStore(EventFederationWorkerStore):
403 """ Responsible for storing and serving up the various graphs associated
404 with an event. Including the main event graph and the auth chains for an
405 event.
406
407 Also has methods for getting the front (latest) and back (oldest) edges
408 of the event graphs. These are used to generate the parents for new events
409 and backfilling from another server respectively.
410 """
411
412 EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
413
414 def __init__(self, db_conn, hs):
415 super(EventFederationStore, self).__init__(db_conn, hs)
416
417 self.register_background_update_handler(
418 self.EVENT_AUTH_STATE_ONLY,
419 self._background_delete_non_state_event_auth,
420 )
421
422 hs.get_clock().looping_call(
423 self._delete_old_forward_extrem_cache, 60 * 60 * 1000
424 )
425
230426 def _update_min_depth_for_room_txn(self, txn, room_id, depth):
231427 min_depth = self._get_min_depth_interaction(txn, room_id)
232428
307503 (ev.event_id, ev.room_id) for ev in events
308504 if not ev.internal_metadata.is_outlier()
309505 ]
310 )
311
312 def get_forward_extremeties_for_room(self, room_id, stream_ordering):
313 """For a given room_id and stream_ordering, return the forward
314 extremeties of the room at that point in "time".
315
316 Throws a StoreError if we have since purged the index for
317 stream_orderings from that point.
318
319 Args:
320 room_id (str):
321 stream_ordering (int):
322
323 Returns:
324 deferred, which resolves to a list of event_ids
325 """
326 # We want to make the cache more effective, so we clamp to the last
327 # change before the given ordering.
328 last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)
329
330 # We don't always have a full stream_to_exterm_id table, e.g. after
331 # the upgrade that introduced it, so we make sure we never ask for a
332 # stream_ordering from before a restart
333 last_change = max(self._stream_order_on_start, last_change)
334
335 # provided the last_change is recent enough, we now clamp the requested
336 # stream_ordering to it.
337 if last_change > self.stream_ordering_month_ago:
338 stream_ordering = min(last_change, stream_ordering)
339
340 return self._get_forward_extremeties_for_room(room_id, stream_ordering)
341
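A hedged standalone sketch of the clamping used by get_forward_extremeties_for_room, with invented numbers, to show why nearby requests share a cache entry:

    def clamp_stream_ordering(stream_ordering, last_change_in_room,
                              stream_order_on_start, stream_ordering_month_ago):
        # mirrors the clamping above, for illustration only
        last_change = max(stream_order_on_start, last_change_in_room)
        if last_change > stream_ordering_month_ago:
            stream_ordering = min(last_change, stream_ordering)
        return stream_ordering

    # e.g. clamp_stream_ordering(6000, 5200, 5000, 1000) == 5200: every request
    # for an ordering in [5200, 6000] maps onto the same cached lookup.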
342 @cached(max_entries=5000, num_args=2)
343 def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
344 """For a given room_id and stream_ordering, return the forward
345 extremeties of the room at that point in "time".
346
347 Throws a StoreError if we have since purged the index for
348 stream_orderings from that point.
349 """
350
351 if stream_ordering <= self.stream_ordering_month_ago:
352 raise StoreError(400, "stream_ordering too old")
353
354 sql = ("""
355 SELECT event_id FROM stream_ordering_to_exterm
356 INNER JOIN (
357 SELECT room_id, MAX(stream_ordering) AS stream_ordering
358 FROM stream_ordering_to_exterm
359 WHERE stream_ordering <= ? GROUP BY room_id
360 ) AS rms USING (room_id, stream_ordering)
361 WHERE room_id = ?
362 """)
363
364 def get_forward_extremeties_for_room_txn(txn):
365 txn.execute(sql, (stream_ordering, room_id))
366 return [event_id for event_id, in txn]
367
368 return self.runInteraction(
369 "get_forward_extremeties_for_room",
370 get_forward_extremeties_for_room_txn
371506 )
372507
373508 def _delete_old_forward_extrem_cache(self):
392527 _delete_old_forward_extrem_cache_txn
393528 )
394529
395 def get_backfill_events(self, room_id, event_list, limit):
396 """Get a list of Events for a given topic that occurred before (and
397 including) the events in event_list. Return a list of max size `limit`
398
399 Args:
400 txn
401 room_id (str)
402 event_list (list)
403 limit (int)
404 """
405 return self.runInteraction(
406 "get_backfill_events",
407 self._get_backfill_events, room_id, event_list, limit
408 ).addCallback(
409 self._get_events
410 ).addCallback(
411 lambda l: sorted(l, key=lambda e: -e.depth)
412 )
413
414 def _get_backfill_events(self, txn, room_id, event_list, limit):
415 logger.debug(
416 "_get_backfill_events: %s, %s, %s",
417 room_id, repr(event_list), limit
418 )
419
420 event_results = set()
421
422 # We want to make sure that we do a breadth-first, "depth" ordered
423 # search.
424
425 query = (
426 "SELECT depth, prev_event_id FROM event_edges"
427 " INNER JOIN events"
428 " ON prev_event_id = events.event_id"
429 " AND event_edges.room_id = events.room_id"
430 " WHERE event_edges.room_id = ? AND event_edges.event_id = ?"
431 " AND event_edges.is_state = ?"
432 " LIMIT ?"
433 )
434
435 queue = PriorityQueue()
436
437 for event_id in event_list:
438 depth = self._simple_select_one_onecol_txn(
439 txn,
440 table="events",
441 keyvalues={
442 "event_id": event_id,
443 },
444 retcol="depth",
445 allow_none=True,
446 )
447
448 if depth:
449 queue.put((-depth, event_id))
450
451 while not queue.empty() and len(event_results) < limit:
452 try:
453 _, event_id = queue.get_nowait()
454 except Empty:
455 break
456
457 if event_id in event_results:
458 continue
459
460 event_results.add(event_id)
461
462 txn.execute(
463 query,
464 (room_id, event_id, False, limit - len(event_results))
465 )
466
467 for row in txn:
468 if row[1] not in event_results:
469 queue.put((-row[0], row[1]))
470
471 return event_results
472
473 @defer.inlineCallbacks
474 def get_missing_events(self, room_id, earliest_events, latest_events,
475 limit, min_depth):
476 ids = yield self.runInteraction(
477 "get_missing_events",
478 self._get_missing_events,
479 room_id, earliest_events, latest_events, limit, min_depth
480 )
481
482 events = yield self._get_events(ids)
483
484 events = sorted(
485 [ev for ev in events if ev.depth >= min_depth],
486 key=lambda e: e.depth,
487 )
488
489 defer.returnValue(events[:limit])
490
491 def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
492 limit, min_depth):
493
494 earliest_events = set(earliest_events)
495 front = set(latest_events) - earliest_events
496
497 event_results = set()
498
499 query = (
500 "SELECT prev_event_id FROM event_edges "
501 "WHERE room_id = ? AND event_id = ? AND is_state = ? "
502 "LIMIT ?"
503 )
504
505 while front and len(event_results) < limit:
506 new_front = set()
507 for event_id in front:
508 txn.execute(
509 query,
510 (room_id, event_id, False, limit - len(event_results))
511 )
512
513 for e_id, in txn:
514 new_front.add(e_id)
515
516 new_front -= earliest_events
517 new_front -= event_results
518
519 front = new_front
520 event_results |= new_front
521
522 return event_results
523
524530 def clean_room_for_join(self, room_id):
525531 return self.runInteraction(
526532 "clean_room_for_join",
00 # -*- coding: utf-8 -*-
11 # Copyright 2015 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1213 # See the License for the specific language governing permissions and
1314 # limitations under the License.
1415
15 from ._base import SQLBaseStore
16 from synapse.storage._base import SQLBaseStore, LoggingTransaction
1617 from twisted.internet import defer
1718 from synapse.util.async import sleep
1819 from synapse.util.caches.descriptors import cachedInlineCallbacks
2021 from .stream import lower_bound
2122
2223 import logging
23 import ujson as json
24 import simplejson as json
2425
2526 logger = logging.getLogger(__name__)
2627
6162 return DEFAULT_NOTIF_ACTION
6263
6364
64 class EventPushActionsStore(SQLBaseStore):
65 EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
66
67 def __init__(self, hs):
68 super(EventPushActionsStore, self).__init__(hs)
69
70 self.register_background_index_update(
71 self.EPA_HIGHLIGHT_INDEX,
72 index_name="event_push_actions_u_highlight",
73 table="event_push_actions",
74 columns=["user_id", "stream_ordering"],
75 )
76
77 self.register_background_index_update(
78 "event_push_actions_highlights_index",
79 index_name="event_push_actions_highlights_index",
80 table="event_push_actions",
81 columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
82 where_clause="highlight=1"
83 )
84
85 self._doing_notif_rotation = False
86 self._rotate_notif_loop = self._clock.looping_call(
87 self._rotate_notifs, 30 * 60 * 1000
88 )
89
90 def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
91 """
92 Args:
93 event: the event to set actions for
94 tuples: list of tuples of (user_id, actions)
95 """
96 values = []
97 for uid, actions in tuples:
98 is_highlight = 1 if _action_has_highlight(actions) else 0
99
100 values.append({
101 'room_id': event.room_id,
102 'event_id': event.event_id,
103 'user_id': uid,
104 'actions': _serialize_action(actions, is_highlight),
105 'stream_ordering': event.internal_metadata.stream_ordering,
106 'topological_ordering': event.depth,
107 'notif': 1,
108 'highlight': is_highlight,
109 })
110
111 for uid, __ in tuples:
112 txn.call_after(
113 self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
114 (event.room_id, uid)
115 )
116 self._simple_insert_many_txn(txn, "event_push_actions", values)
65 class EventPushActionsWorkerStore(SQLBaseStore):
66 def __init__(self, db_conn, hs):
67 super(EventPushActionsWorkerStore, self).__init__(db_conn, hs)
68
69 # These get correctly set by _find_stream_orderings_for_times_txn
70 self.stream_ordering_month_ago = None
71 self.stream_ordering_day_ago = None
72
73 cur = LoggingTransaction(
74 db_conn.cursor(),
75 name="_find_stream_orderings_for_times_txn",
76 database_engine=self.database_engine,
77 after_callbacks=[],
78 exception_callbacks=[],
79 )
80 self._find_stream_orderings_for_times_txn(cur)
81 cur.close()
82
83 self.find_stream_orderings_looping_call = self._clock.looping_call(
84 self._find_stream_orderings_for_times, 10 * 60 * 1000
85 )
11786
11887 @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
11988 def get_unread_event_push_actions_by_room_for_user(
431400 # Now return the first `limit`
432401 defer.returnValue(notifs[:limit])
433402
403 def add_push_actions_to_staging(self, event_id, user_id_actions):
404 """Add the push actions for the event to the push action staging area.
405
406 Args:
407 event_id (str)
408 user_id_actions (dict[str, list[dict|str]]): A dictionary mapping
409 user_id to list of push actions, where an action can either be
410 a string or dict.
411
412 Returns:
413 Deferred
414 """
415
416 if not user_id_actions:
417 return
418
419 # This is a helper function for generating the necessary tuple that
420 # can be used to insert into the `event_push_actions_staging` table.
421 def _gen_entry(user_id, actions):
422 is_highlight = 1 if _action_has_highlight(actions) else 0
423 return (
424 event_id, # event_id column
425 user_id, # user_id column
426 _serialize_action(actions, is_highlight), # actions column
427 1, # notif column
428 is_highlight, # highlight column
429 )
430
431 def _add_push_actions_to_staging_txn(txn):
432 # We don't use _simple_insert_many here to avoid the overhead
433 # of generating lists of dicts.
434
435 sql = """
436 INSERT INTO event_push_actions_staging
437 (event_id, user_id, actions, notif, highlight)
438 VALUES (?, ?, ?, ?, ?)
439 """
440
441 txn.executemany(sql, (
442 _gen_entry(user_id, actions)
443 for user_id, actions in user_id_actions.iteritems()
444 ))
445
446 return self.runInteraction(
447 "add_push_actions_to_staging", _add_push_actions_to_staging_txn
448 )
449
450 def remove_push_actions_from_staging(self, event_id):
451 """Called if we failed to persist the event to ensure that stale push
452 actions don't build up in the DB
453
454 Args:
455 event_id (str)
456 """
457
458 return self._simple_delete(
459 table="event_push_actions_staging",
460 keyvalues={
461 "event_id": event_id,
462 },
463 desc="remove_push_actions_from_staging",
464 )
465
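A hypothetical usage sketch of the staging flow: actions are staged before the event is persisted, then either moved into event_push_actions during persistence or removed again on failure. Here store, persist_the_event_somehow and the user IDs are illustrative assumptions.

    from twisted.internet import defer

    @defer.inlineCallbacks
    def stage_actions_example(store, event_id):
        # stage the per-user push actions before attempting to persist the event
        yield store.add_push_actions_to_staging(event_id, {
            "@alice:example.org": ["notify", {"set_tweak": "highlight"}],
            "@bob:example.org": ["notify"],
        })
        try:
            yield persist_the_event_somehow(event_id)  # hypothetical helper
        except Exception:
            # avoid leaving stale rows in event_push_actions_staging
            yield store.remove_push_actions_from_staging(event_id)
            raise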
466 @defer.inlineCallbacks
467 def _find_stream_orderings_for_times(self):
468 yield self.runInteraction(
469 "_find_stream_orderings_for_times",
470 self._find_stream_orderings_for_times_txn
471 )
472
473 def _find_stream_orderings_for_times_txn(self, txn):
474 logger.info("Searching for stream ordering 1 month ago")
475 self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
476 txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
477 )
478 logger.info(
479 "Found stream ordering 1 month ago: it's %d",
480 self.stream_ordering_month_ago
481 )
482 logger.info("Searching for stream ordering 1 day ago")
483 self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
484 txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
485 )
486 logger.info(
487 "Found stream ordering 1 day ago: it's %d",
488 self.stream_ordering_day_ago
489 )
490
491 def find_first_stream_ordering_after_ts(self, ts):
492 """Gets the stream ordering corresponding to a given timestamp.
493
494 Specifically, finds the stream_ordering of the first event that was
495 received on or after the timestamp. This is done by a binary search on
496 the events table; since there is no index on received_ts, it is
497 relatively slow.
498
499 Args:
500 ts (int): timestamp in millis
501
502 Returns:
503 Deferred[int]: stream ordering of the first event received on/after
504 the timestamp
505 """
506 return self.runInteraction(
507 "_find_first_stream_ordering_after_ts_txn",
508 self._find_first_stream_ordering_after_ts_txn,
509 ts,
510 )
511
512 @staticmethod
513 def _find_first_stream_ordering_after_ts_txn(txn, ts):
514 """
515 Find the stream_ordering of the first event that was received on or
516 after a given timestamp. This is relatively slow as there is no index
517 on received_ts but we can then use this to delete push actions before
518 this.
519
520 received_ts must necessarily be in the same order as stream_ordering
521 and stream_ordering is indexed, so we manually binary search using
522 stream_ordering
523
524 Args:
525 txn (twisted.enterprise.adbapi.Transaction):
526 ts (int): timestamp to search for
527
528 Returns:
529 int: stream ordering
530 """
531 txn.execute("SELECT MAX(stream_ordering) FROM events")
532 max_stream_ordering = txn.fetchone()[0]
533
534 if max_stream_ordering is None:
535 return 0
536
537 # We want the first stream_ordering in which received_ts is greater
538 # than or equal to ts. Call this point X.
539 #
540 # We maintain the invariants:
541 #
542 # range_start <= X <= range_end
543 #
544 range_start = 0
545 range_end = max_stream_ordering + 1
546
547 # Given a stream_ordering, look up the timestamp at that
548 # stream_ordering.
549 #
550 # The array may be sparse (we may be missing some stream_orderings).
551 # We treat the gaps as the same as having the same value as the
552 # preceding entry, because we will pick the lowest stream_ordering
553 # which satisfies our requirement of received_ts >= ts.
554 #
555 # For example, if our array of events indexed by stream_ordering is
556 # [10, <none>, 20], we should treat this as being equivalent to
557 # [10, 10, 20].
558 #
559 sql = (
560 "SELECT received_ts FROM events"
561 " WHERE stream_ordering <= ?"
562 " ORDER BY stream_ordering DESC"
563 " LIMIT 1"
564 )
565
566 while range_end - range_start > 0:
567 middle = (range_end + range_start) // 2
568 txn.execute(sql, (middle,))
569 row = txn.fetchone()
570 if row is None:
571 # no rows with stream_ordering<=middle
572 range_start = middle + 1
573 continue
574
575 middle_ts = row[0]
576 if ts > middle_ts:
577 # we got a timestamp lower than the one we were looking for.
578 # definitely need to look higher: X > middle.
579 range_start = middle + 1
580 else:
581 # we got a timestamp higher than (or the same as) the one we
582 # were looking for. We aren't yet sure about the point we
583 # looked up, but we can be sure that X <= middle.
584 range_end = middle
585
586 return range_end
587
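For illustration, the same binary search run against an in-memory, possibly sparse, mapping of stream_ordering -> received_ts; received_ts_by_ordering is hypothetical test data standing in for the events table.

    def find_first_ordering_after_ts(received_ts_by_ordering, ts):
        if not received_ts_by_ordering:
            return 0
        max_stream_ordering = max(received_ts_by_ordering)

        range_start = 0
        range_end = max_stream_ordering + 1
        while range_end - range_start > 0:
            middle = (range_end + range_start) // 2
            # emulate "SELECT received_ts ... WHERE stream_ordering <= ?
            #          ORDER BY stream_ordering DESC LIMIT 1"
            candidates = [o for o in received_ts_by_ordering if o <= middle]
            if not candidates:
                range_start = middle + 1
                continue
            middle_ts = received_ts_by_ordering[max(candidates)]
            if ts > middle_ts:
                range_start = middle + 1
            else:
                range_end = middle
        return range_end

    # e.g. with orderings {1: 10, 3: 20} (ordering 2 missing, treated as ts 10):
    # find_first_ordering_after_ts({1: 10, 3: 20}, 15) == 3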
588
589 class EventPushActionsStore(EventPushActionsWorkerStore):
590 EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
591
592 def __init__(self, db_conn, hs):
593 super(EventPushActionsStore, self).__init__(db_conn, hs)
594
595 self.register_background_index_update(
596 self.EPA_HIGHLIGHT_INDEX,
597 index_name="event_push_actions_u_highlight",
598 table="event_push_actions",
599 columns=["user_id", "stream_ordering"],
600 )
601
602 self.register_background_index_update(
603 "event_push_actions_highlights_index",
604 index_name="event_push_actions_highlights_index",
605 table="event_push_actions",
606 columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
607 where_clause="highlight=1"
608 )
609
610 self._doing_notif_rotation = False
611 self._rotate_notif_loop = self._clock.looping_call(
612 self._rotate_notifs, 30 * 60 * 1000
613 )
614
615 def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts,
616 all_events_and_contexts):
617 """Handles moving push actions from staging table to main
618 event_push_actions table for all events in `events_and_contexts`.
619
620 Also ensures that all events in `all_events_and_contexts` are removed
621 from the push action staging area.
622
623 Args:
624 events_and_contexts (list[(EventBase, EventContext)]): events
625 we are persisting
626 all_events_and_contexts (list[(EventBase, EventContext)]): all
627 events that we were going to persist. This includes events
628 we've already persisted, etc, that wouldn't appear in
629 events_and_contexts.
630 """
631
632 sql = """
633 INSERT INTO event_push_actions (
634 room_id, event_id, user_id, actions, stream_ordering,
635 topological_ordering, notif, highlight
636 )
637 SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
638 FROM event_push_actions_staging
639 WHERE event_id = ?
640 """
641
642 if events_and_contexts:
643 txn.executemany(sql, (
644 (
645 event.room_id, event.internal_metadata.stream_ordering,
646 event.depth, event.event_id,
647 )
648 for event, _ in events_and_contexts
649 ))
650
651 for event, _ in events_and_contexts:
652 user_ids = self._simple_select_onecol_txn(
653 txn,
654 table="event_push_actions_staging",
655 keyvalues={
656 "event_id": event.event_id,
657 },
658 retcol="user_id",
659 )
660
661 for uid in user_ids:
662 txn.call_after(
663 self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
664 (event.room_id, uid,)
665 )
666
667 # Now we delete the staging area for *all* events that were being
668 # persisted.
669 txn.executemany(
670 "DELETE FROM event_push_actions_staging WHERE event_id = ?",
671 (
672 (event.event_id,)
673 for event, _ in all_events_and_contexts
674 )
675 )
676
434677 @defer.inlineCallbacks
435678 def get_push_actions_for_user(self, user_id, before=None, limit=50,
436679 only_highlight=False):
548791 DELETE FROM event_push_summary
549792 WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
550793 """, (room_id, user_id, stream_ordering))
551
552 @defer.inlineCallbacks
553 def _find_stream_orderings_for_times(self):
554 yield self.runInteraction(
555 "_find_stream_orderings_for_times",
556 self._find_stream_orderings_for_times_txn
557 )
558
559 def _find_stream_orderings_for_times_txn(self, txn):
560 logger.info("Searching for stream ordering 1 month ago")
561 self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
562 txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
563 )
564 logger.info(
565 "Found stream ordering 1 month ago: it's %d",
566 self.stream_ordering_month_ago
567 )
568 logger.info("Searching for stream ordering 1 day ago")
569 self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
570 txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
571 )
572 logger.info(
573 "Found stream ordering 1 day ago: it's %d",
574 self.stream_ordering_day_ago
575 )
576
577 def _find_first_stream_ordering_after_ts_txn(self, txn, ts):
578 """
579 Find the stream_ordering of the first event that was received after
580 a given timestamp. This is relatively slow as there is no index on
581 received_ts but we can then use this to delete push actions before
582 this.
583
584 received_ts must necessarily be in the same order as stream_ordering
585 and stream_ordering is indexed, so we manually binary search using
586 stream_ordering
587 """
588 txn.execute("SELECT MAX(stream_ordering) FROM events")
589 max_stream_ordering = txn.fetchone()[0]
590
591 if max_stream_ordering is None:
592 return 0
593
594 range_start = 0
595 range_end = max_stream_ordering
596
597 sql = (
598 "SELECT received_ts FROM events"
599 " WHERE stream_ordering > ?"
600 " ORDER BY stream_ordering"
601 " LIMIT 1"
602 )
603
604 while range_end - range_start > 1:
605 middle = int((range_end + range_start) / 2)
606 txn.execute(sql, (middle,))
607 middle_ts = txn.fetchone()[0]
608 if ts > middle_ts:
609 range_start = middle
610 else:
611 range_end = middle
612
613 return range_end
614794
615795 @defer.inlineCallbacks
616796 def _rotate_notifs(self):
00 # -*- coding: utf-8 -*-
11 # Copyright 2014-2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1112 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1213 # See the License for the specific language governing permissions and
1314 # limitations under the License.
14 from ._base import SQLBaseStore
15
16 from twisted.internet import defer, reactor
17
18 from synapse.events import FrozenEvent, USE_FROZEN_DICTS
19 from synapse.events.utils import prune_event
15
16 from synapse.storage.events_worker import EventsWorkerStore
17
18 from twisted.internet import defer
19
20 from synapse.events import USE_FROZEN_DICTS
2021
2122 from synapse.util.async import ObservableDeferred
2223 from synapse.util.logcontext import (
23 preserve_fn, PreserveLoggingContext, make_deferred_yieldable
24 PreserveLoggingContext, make_deferred_yieldable
2425 )
2526 from synapse.util.logutils import log_function
2627 from synapse.util.metrics import Measure
2728 from synapse.api.constants import EventTypes
2829 from synapse.api.errors import SynapseError
29 from synapse.state import resolve_events
30 from synapse.util.caches.descriptors import cached
30 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
3131 from synapse.types import get_domain_from_id
3232
3333 from canonicaljson import encode_canonical_json
3737 import synapse.metrics
3838
3939 import logging
40 import ujson as json
40 import simplejson as json
4141
4242 # these are only included to make the type annotations work
4343 from synapse.events import EventBase # noqa: F401
5555
5656 def encode_json(json_object):
5757 if USE_FROZEN_DICTS:
58 # ujson doesn't like frozen_dicts
5958 return encode_canonical_json(json_object)
6059 else:
6160 return json.dumps(json_object, ensure_ascii=False)
62
63
64 # These values are used in the `_enqueue_events` and `_do_fetch` methods to
65 # control how we batch/bulk fetch events from the database.
66 # The values are plucked out of thin air to make initial sync run faster
67 # on jki.re
68 # TODO: Make these configurable.
69 EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
70 EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
71 EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
7261
7362
7463 class _EventPeristenceQueue(object):
10998 end_item.events_and_contexts.extend(events_and_contexts)
11099 return end_item.deferred.observe()
111100
112 deferred = ObservableDeferred(defer.Deferred())
101 deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
113102
114103 queue.append(self._EventPersistQueueItem(
115104 events_and_contexts=events_and_contexts,
145134 try:
146135 queue = self._get_drainining_queue(room_id)
147136 for item in queue:
137 # handle_queue_loop runs in the sentinel logcontext, so
138 # there is no need to preserve_fn when running the
139 # callbacks on the deferred.
148140 try:
149141 ret = yield per_item_callback(item)
150142 item.deferred.callback(ret)
151 except Exception as e:
152 item.deferred.errback(e)
143 except Exception:
144 item.deferred.errback()
153145 finally:
154146 queue = self._event_persist_queues.pop(room_id, None)
155147 if queue:
156148 self._event_persist_queues[room_id] = queue
157149 self._currently_persisting_rooms.discard(room_id)
158150
159 preserve_fn(handle_queue_loop)()
151 # set handle_queue_loop off on the background. We don't want to
152 # attribute work done in it to the current request, so we drop the
153 # logcontext altogether.
154 with PreserveLoggingContext():
155 handle_queue_loop()
160156
161157 def _get_drainining_queue(self, room_id):
162158 queue = self._event_persist_queues.setdefault(room_id, deque())
192188 return f
193189
194190
195 class EventsStore(SQLBaseStore):
191 class EventsStore(EventsWorkerStore):
196192 EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
197193 EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
198194
199 def __init__(self, hs):
200 super(EventsStore, self).__init__(hs)
201 self._clock = hs.get_clock()
195 def __init__(self, db_conn, hs):
196 super(EventsStore, self).__init__(db_conn, hs)
202197 self.register_background_update_handler(
203198 self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
204199 )
229224
230225 self._event_persist_queue = _EventPeristenceQueue()
231226
227 self._state_resolution_handler = hs.get_state_resolution_handler()
228
232229 def persist_events(self, events_and_contexts, backfilled=False):
233230 """
234231 Write events to the database
284281 def _maybe_start_persisting(self, room_id):
285282 @defer.inlineCallbacks
286283 def persisting_queue(item):
287 yield self._persist_events(
288 item.events_and_contexts,
289 backfilled=item.backfilled,
290 )
284 with Measure(self._clock, "persist_events"):
285 yield self._persist_events(
286 item.events_and_contexts,
287 backfilled=item.backfilled,
288 )
291289
292290 self._event_persist_queue.handle_queue(room_id, persisting_queue)
293291
334332
335333 # NB: Assumes that we are only persisting events for one room
336334 # at a time.
335
336 # map room_id->list[event_ids] giving the new forward
337 # extremities in each room
337338 new_forward_extremeties = {}
339
340 # map room_id->(type,state_key)->event_id tracking the full
341 # state in each room after adding these events
338342 current_state_for_room = {}
343
344 # map room_id->(to_delete, to_insert) where each entry is
345 # a map (type,key)->event_id giving the state delta in each
346 # room
347 state_delta_for_room = {}
348
339349 if not backfilled:
340350 with Measure(self._clock, "_calculate_state_and_extrem"):
341351 # Work out the new "current state" for each room.
378388 if all_single_prev_not_state:
379389 continue
380390
381 state = yield self._calculate_state_delta(
382 room_id, ev_ctx_rm, new_latest_event_ids
391 logger.info(
392 "Calculating state delta for room %s", room_id,
383393 )
384 if state:
385 current_state_for_room[room_id] = state
394 current_state = yield self._get_new_state_after_events(
395 room_id,
396 ev_ctx_rm, new_latest_event_ids,
397 )
398 if current_state is not None:
399 current_state_for_room[room_id] = current_state
400 delta = yield self._calculate_state_delta(
401 room_id, current_state,
402 )
403 if delta is not None:
404 state_delta_for_room[room_id] = delta
386405
387406 yield self.runInteraction(
388407 "persist_events",
390409 events_and_contexts=chunk,
391410 backfilled=backfilled,
392411 delete_existing=delete_existing,
393 current_state_for_room=current_state_for_room,
412 state_delta_for_room=state_delta_for_room,
394413 new_forward_extremeties=new_forward_extremeties,
395414 )
396415 persist_event_counter.inc_by(len(chunk))
407426
408427 event_counter.inc(event.type, origin_type, origin_entity)
409428
410 for room_id, (_, _, new_state) in current_state_for_room.iteritems():
429 for room_id, new_state in current_state_for_room.iteritems():
411430 self.get_current_state_ids.prefill(
412431 (room_id, ), new_state
413432 )
459478 defer.returnValue(new_latest_event_ids)
460479
461480 @defer.inlineCallbacks
462 def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
463 """Calculate the new state deltas for a room.
464
465 Assumes that we are only persisting events for one room at a time.
481 def _get_new_state_after_events(self, room_id, events_context, new_latest_event_ids):
482 """Calculate the current state dict after adding some new events to
483 a room
484
485 Args:
486 room_id (str):
487 room to which the events are being added. Used for logging etc
488
489 events_context (list[(EventBase, EventContext)]):
490 events and contexts which are being added to the room
491
492 new_latest_event_ids (iterable[str]):
493 the new forward extremities for the room.
466494
467495 Returns:
468 3-tuple (to_delete, to_insert, new_state) where both are state dicts,
469 i.e. (type, state_key) -> event_id. `to_delete` are the entries to
470 first be deleted from current_state_events, `to_insert` are entries
471 to insert. `new_state` is the full set of state.
472 May return None if there are no changes to be applied.
473 """
474 # Now we need to work out the different state sets for
475 # each state extremities
476 state_sets = []
477 state_groups = set()
496 Deferred[dict[(str,str), str]|None]:
497 None if there are no changes to the room state, or
498 a dict of (type, state_key) -> event_id].
499 """
500
501 if not new_latest_event_ids:
502 defer.returnValue({})
503
504 # map from state_group to ((type, key) -> event_id) state map
505 state_groups = {}
478506 missing_event_ids = []
479507 was_updated = False
480508 for event_id in new_latest_event_ids:
485513 if ctx.current_state_ids is None:
486514 raise Exception("Unknown current state")
487515
516 if ctx.state_group is None:
517 # I don't think this can happen, but let's double-check
518 raise Exception(
519 "Context for new extremity event %s has no state "
520 "group" % (event_id, ),
521 )
522
488523 # If we've already seen the state group don't bother adding
489524 # it to the state sets again
490525 if ctx.state_group not in state_groups:
491 state_sets.append(ctx.current_state_ids)
526 state_groups[ctx.state_group] = ctx.current_state_ids
492527 if ctx.delta_ids or hasattr(ev, "state_key"):
493528 was_updated = True
494 if ctx.state_group:
495 # Add this as a seen state group (if it has a state
496 # group)
497 state_groups.add(ctx.state_group)
498529 break
499530 else:
500531 # If we couldn't find it, then we'll need to pull
502533 was_updated = True
503534 missing_event_ids.append(event_id)
504535
536 if not was_updated:
537 return
538
505539 if missing_event_ids:
506540 # Now pull out the state for any missing events from DB
507541 event_to_groups = yield self._get_state_group_for_events(
508542 missing_event_ids,
509543 )
510544
511 groups = set(event_to_groups.itervalues()) - state_groups
545 groups = set(event_to_groups.itervalues()) - set(state_groups.iterkeys())
512546
513547 if groups:
514548 group_to_state = yield self._get_state_for_groups(groups)
515 state_sets.extend(group_to_state.itervalues())
516
517 if not new_latest_event_ids:
518 current_state = {}
519 elif was_updated:
520 if len(state_sets) == 1:
521 # If there is only one state set, then we know what the current
522 # state is.
523 current_state = state_sets[0]
524 else:
525 # We work out the current state by passing the state sets to the
526 # state resolution algorithm. It may ask for some events, including
527 # the events we have yet to persist, so we need a slightly more
528 # complicated event lookup function than simply looking the events
529 # up in the db.
530 events_map = {ev.event_id: ev for ev, _ in events_context}
531
532 @defer.inlineCallbacks
533 def get_events(ev_ids):
534 # We get the events by first looking at the list of events we
535 # are trying to persist, and then fetching the rest from the DB.
536 db = []
537 to_return = {}
538 for ev_id in ev_ids:
539 ev = events_map.get(ev_id, None)
540 if ev:
541 to_return[ev_id] = ev
542 else:
543 db.append(ev_id)
544
545 if db:
546 evs = yield self.get_events(
547 ev_ids, get_prev_content=False, check_redacted=False,
548 )
549 to_return.update(evs)
550 defer.returnValue(to_return)
551
552 current_state = yield resolve_events(
553 state_sets,
554 state_map_factory=get_events,
555 )
556 else:
557 return
558
549 state_groups.update(group_to_state)
550
551 if len(state_groups) == 1:
552 # If there is only one state group, then we know what the current
553 # state is.
554 defer.returnValue(state_groups.values()[0])
555
556 def get_events(ev_ids):
557 return self.get_events(
558 ev_ids, get_prev_content=False, check_redacted=False,
559 )
560 events_map = {ev.event_id: ev for ev, _ in events_context}
561 logger.debug("calling resolve_state_groups from preserve_events")
562 res = yield self._state_resolution_handler.resolve_state_groups(
563 room_id, state_groups, events_map, get_events
564 )
565
566 defer.returnValue(res.state)
567
568 @defer.inlineCallbacks
569 def _calculate_state_delta(self, room_id, current_state):
570 """Calculate the new state deltas for a room.
571
572 Assumes that we are only persisting events for one room at a time.
573
574 Returns:
575 2-tuple (to_delete, to_insert) where both are state dicts,
576 i.e. (type, state_key) -> event_id. `to_delete` are the entries to
577 first be deleted from current_state_events, `to_insert` are entries
578 to insert.
579 """
559580 existing_state = yield self.get_current_state_ids(room_id)
560581
561582 existing_events = set(existing_state.itervalues())
575596 if ev_id in events_to_insert
576597 }
577598
578 defer.returnValue((to_delete, to_insert, current_state))
579
580 @defer.inlineCallbacks
581 def get_event(self, event_id, check_redacted=True,
582 get_prev_content=False, allow_rejected=False,
583 allow_none=False):
584 """Get an event from the database by event_id.
585
586 Args:
587 event_id (str): The event_id of the event to fetch
588 check_redacted (bool): If True, check if event has been redacted
589 and redact it.
590 get_prev_content (bool): If True and event is a state event,
591 include the previous states content in the unsigned field.
592 allow_rejected (bool): If True return rejected events.
593 allow_none (bool): If True, return None if no event found, if
594 False throw an exception.
595
596 Returns:
597 Deferred : A FrozenEvent.
598 """
599 events = yield self._get_events(
600 [event_id],
601 check_redacted=check_redacted,
602 get_prev_content=get_prev_content,
603 allow_rejected=allow_rejected,
604 )
605
606 if not events and not allow_none:
607 raise SynapseError(404, "Could not find event %s" % (event_id,))
608
609 defer.returnValue(events[0] if events else None)
610
611 @defer.inlineCallbacks
612 def get_events(self, event_ids, check_redacted=True,
613 get_prev_content=False, allow_rejected=False):
614 """Get events from the database
615
616 Args:
617 event_ids (list): The event_ids of the events to fetch
618 check_redacted (bool): If True, check if event has been redacted
619 and redact it.
620 get_prev_content (bool): If True and event is a state event,
621 include the previous states content in the unsigned field.
622 allow_rejected (bool): If True return rejected events.
623
624 Returns:
625 Deferred : Dict from event_id to event.
626 """
627 events = yield self._get_events(
628 event_ids,
629 check_redacted=check_redacted,
630 get_prev_content=get_prev_content,
631 allow_rejected=allow_rejected,
632 )
633
634 defer.returnValue({e.event_id: e for e in events})
599 defer.returnValue((to_delete, to_insert))
635600
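Roughly, the delta computed by _calculate_state_delta amounts to diffing two (type, state_key) -> event_id maps; the sketch below is an approximation for illustration, not the exact implementation.

    def sketch_state_delta(existing_state, current_state):
        # entries whose event_id is no longer part of the new current state
        to_delete = {
            key: ev_id for key, ev_id in existing_state.items()
            if current_state.get(key) != ev_id
        }
        # entries that are new or now point at a different event
        to_insert = {
            key: ev_id for key, ev_id in current_state.items()
            if existing_state.get(key) != ev_id
        }
        return to_delete, to_insert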
636601 @log_function
637602 def _persist_events_txn(self, txn, events_and_contexts, backfilled,
638 delete_existing=False, current_state_for_room={},
603 delete_existing=False, state_delta_for_room={},
639604 new_forward_extremeties={}):
640605 """Insert some number of room events into the necessary database tables.
641606
651616 delete_existing (bool): True to purge existing table rows for the
652617 events from the database. This is useful when retrying due to
653618 IntegrityError.
654 current_state_for_room (dict[str, (list[str], list[str])]):
619 state_delta_for_room (dict[str, (list[str], list[str])]):
655620 The current-state delta for each room. For each room, a tuple
656621 (to_delete, to_insert), being a list of event ids to be removed
657622 from the current state, and a list of event ids to be added to
661626 list of the event ids which are the forward extremities.
662627
663628 """
629 all_events_and_contexts = events_and_contexts
630
664631 max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
665632
666 self._update_current_state_txn(txn, current_state_for_room, max_stream_order)
633 self._update_current_state_txn(txn, state_delta_for_room, max_stream_order)
667634
668635 self._update_forward_extremities_txn(
669636 txn,
707674 events_and_contexts=events_and_contexts,
708675 )
709676
710 # Insert into the state_groups, state_groups_state, and
711 # event_to_state_groups tables.
712 self._store_mult_state_groups_txn(txn, events_and_contexts)
677 # Insert into event_to_state_groups.
678 self._store_event_state_mappings_txn(txn, events_and_contexts)
713679
714680 # _store_rejected_events_txn filters out any events which were
715681 # rejected, and returns the filtered list.
724690 self._update_metadata_tables_txn(
725691 txn,
726692 events_and_contexts=events_and_contexts,
693 all_events_and_contexts=all_events_and_contexts,
727694 backfilled=backfilled,
728695 )
729696
730697 def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
731698 for room_id, current_state_tuple in state_delta_by_room.iteritems():
732 to_delete, to_insert, _ = current_state_tuple
699 to_delete, to_insert = current_state_tuple
733700 txn.executemany(
734701 "DELETE FROM current_state_events WHERE event_id = ?",
735702 [(ev_id,) for ev_id in to_delete.itervalues()],
786753
787754 for member in members_changed:
788755 self._invalidate_cache_and_stream(
789 txn, self.get_rooms_for_user, (member,)
756 txn, self.get_rooms_for_user_with_stream_ordering, (member,)
790757 )
791758
792759 for host in set(get_domain_from_id(u) for u in members_changed):
944911 # an outlier in the database. We now have some state at that
945912 # so we need to update the state_groups table with that state.
946913
947 # insert into the state_group, state_groups_state and
948 # event_to_state_groups tables.
914 # insert into event_to_state_groups.
949915 try:
950 self._store_mult_state_groups_txn(txn, ((event, context),))
916 self._store_event_state_mappings_txn(txn, ((event, context),))
951917 except Exception:
952918 logger.exception("")
953919 raise
11221088 ec for ec in events_and_contexts if ec[0] not in to_remove
11231089 ]
11241090
1125 def _update_metadata_tables_txn(self, txn, events_and_contexts, backfilled):
1091 def _update_metadata_tables_txn(self, txn, events_and_contexts,
1092 all_events_and_contexts, backfilled):
11261093 """Update all the miscellaneous tables for new events
11271094
11281095 Args:
11291096 txn (twisted.enterprise.adbapi.Connection): db connection
11301097 events_and_contexts (list[(EventBase, EventContext)]): events
11311098 we are persisting
1099 all_events_and_contexts (list[(EventBase, EventContext)]): all
1100 events that we were going to persist. This includes events
1101 we've already persisted, etc, that wouldn't appear in
1102 events_and_contexts.
11321103 backfilled (bool): True if the events were backfilled
11331104 """
1105
1106 # Insert all the push actions into the event_push_actions table.
1107 self._set_push_actions_for_event_and_users_txn(
1108 txn,
1109 events_and_contexts=events_and_contexts,
1110 all_events_and_contexts=all_events_and_contexts,
1111 )
11341112
11351113 if not events_and_contexts:
11361114 # nothing to do here
11371115 return
11381116
11391117 for event, context in events_and_contexts:
1140 # Insert all the push actions into the event_push_actions table.
1141 if context.push_actions:
1142 self._set_push_actions_for_event_and_users_txn(
1143 txn, event, context.push_actions
1144 )
1145
11461118 if event.type == EventTypes.Redaction and event.redacts is not None:
11471119 # Remove the entries in the event_push_actions table for the
11481120 # redacted event.
13471319 )
13481320
13491321 @defer.inlineCallbacks
1350 def _get_events(self, event_ids, check_redacted=True,
1351 get_prev_content=False, allow_rejected=False):
1352 if not event_ids:
1353 defer.returnValue([])
1354
1355 event_id_list = event_ids
1356 event_ids = set(event_ids)
1357
1358 event_entry_map = self._get_events_from_cache(
1359 event_ids,
1360 allow_rejected=allow_rejected,
1361 )
1362
1363 missing_events_ids = [e for e in event_ids if e not in event_entry_map]
1364
1365 if missing_events_ids:
1366 missing_events = yield self._enqueue_events(
1367 missing_events_ids,
1368 check_redacted=check_redacted,
1369 allow_rejected=allow_rejected,
1370 )
1371
1372 event_entry_map.update(missing_events)
1373
1374 events = []
1375 for event_id in event_id_list:
1376 entry = event_entry_map.get(event_id, None)
1377 if not entry:
1378 continue
1379
1380 if allow_rejected or not entry.event.rejected_reason:
1381 if check_redacted and entry.redacted_event:
1382 event = entry.redacted_event
1383 else:
1384 event = entry.event
1385
1386 events.append(event)
1387
1388 if get_prev_content:
1389 if "replaces_state" in event.unsigned:
1390 prev = yield self.get_event(
1391 event.unsigned["replaces_state"],
1392 get_prev_content=False,
1393 allow_none=True,
1394 )
1395 if prev:
1396 event.unsigned = dict(event.unsigned)
1397 event.unsigned["prev_content"] = prev.content
1398 event.unsigned["prev_sender"] = prev.sender
1399
1400 defer.returnValue(events)
1401
1402 def _invalidate_get_event_cache(self, event_id):
1403 self._get_event_cache.invalidate((event_id,))
1404
1405 def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
1406 """Fetch events from the caches
1407
1408 Args:
1409 events (list(str)): list of event_ids to fetch
1410 allow_rejected (bool): Whether to return events that were rejected
1411 update_metrics (bool): Whether to update the cache hit ratio metrics
1412
1413 Returns:
1414 dict of event_id -> _EventCacheEntry for each event_id in cache. If
1415 allow_rejected is `False` then there will still be an entry but it
1416 will be `None`
1417 """
1418 event_map = {}
1419
1420 for event_id in events:
1421 ret = self._get_event_cache.get(
1422 (event_id,), None,
1423 update_metrics=update_metrics,
1424 )
1425 if not ret:
1426 continue
1427
1428 if allow_rejected or not ret.event.rejected_reason:
1429 event_map[event_id] = ret
1430 else:
1431 event_map[event_id] = None
1432
1433 return event_map
1434
1435 def _do_fetch(self, conn):
1436 """Takes a database connection and waits for requests for events from
1437 the _event_fetch_list queue.
1438 """
1439 event_list = []
1440 i = 0
1441 while True:
1442 try:
1443 with self._event_fetch_lock:
1444 event_list = self._event_fetch_list
1445 self._event_fetch_list = []
1446
1447 if not event_list:
1448 single_threaded = self.database_engine.single_threaded
1449 if single_threaded or i > EVENT_QUEUE_ITERATIONS:
1450 self._event_fetch_ongoing -= 1
1451 return
1452 else:
1453 self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
1454 i += 1
1455 continue
1456 i = 0
1457
1458 event_id_lists = zip(*event_list)[0]
1459 event_ids = [
1460 item for sublist in event_id_lists for item in sublist
1461 ]
1462
1463 rows = self._new_transaction(
1464 conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
1465 )
1466
1467 row_dict = {
1468 r["event_id"]: r
1469 for r in rows
1470 }
1471
1472 # We only want to resolve deferreds from the main thread
1473 def fire(lst, res):
1474 for ids, d in lst:
1475 if not d.called:
1476 try:
1477 with PreserveLoggingContext():
1478 d.callback([
1479 res[i]
1480 for i in ids
1481 if i in res
1482 ])
1483 except:
1484 logger.exception("Failed to callback")
1485 with PreserveLoggingContext():
1486 reactor.callFromThread(fire, event_list, row_dict)
1487 except Exception as e:
1488 logger.exception("do_fetch")
1489
1490 # We only want to resolve deferreds from the main thread
1491 def fire(evs):
1492 for _, d in evs:
1493 if not d.called:
1494 with PreserveLoggingContext():
1495 d.errback(e)
1496
1497 if event_list:
1498 with PreserveLoggingContext():
1499 reactor.callFromThread(fire, event_list)
1500
1501 @defer.inlineCallbacks
1502 def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
1503 """Fetches events from the database using the _event_fetch_list. This
1504 allows batch and bulk fetching of events - it allows us to fetch events
1505 without having to create a new transaction for each request for events.
1506 """
1507 if not events:
1508 defer.returnValue({})
1509
1510 events_d = defer.Deferred()
1511 with self._event_fetch_lock:
1512 self._event_fetch_list.append(
1513 (events, events_d)
1514 )
1515
1516 self._event_fetch_lock.notify()
1517
1518 if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
1519 self._event_fetch_ongoing += 1
1520 should_start = True
1521 else:
1522 should_start = False
1523
1524 if should_start:
1525 with PreserveLoggingContext():
1526 self.runWithConnection(
1527 self._do_fetch
1528 )
1529
1530 logger.debug("Loading %d events", len(events))
1531 with PreserveLoggingContext():
1532 rows = yield events_d
1533 logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
1534
1535 if not allow_rejected:
1536 rows[:] = [r for r in rows if not r["rejects"]]
1537
1538 res = yield make_deferred_yieldable(defer.gatherResults(
1539 [
1540 preserve_fn(self._get_event_from_row)(
1541 row["internal_metadata"], row["json"], row["redacts"],
1542 rejected_reason=row["rejects"],
1543 )
1544 for row in rows
1545 ],
1546 consumeErrors=True
1547 ))
1548
1549 defer.returnValue({
1550 e.event.event_id: e
1551 for e in res if e
1552 })
1553
1554 def _fetch_event_rows(self, txn, events):
1555 rows = []
1556 N = 200
1557 for i in range(1 + len(events) / N):
1558 evs = events[i * N:(i + 1) * N]
1559 if not evs:
1560 break
1561
1562 sql = (
1563 "SELECT "
1564 " e.event_id as event_id, "
1565 " e.internal_metadata,"
1566 " e.json,"
1567 " r.redacts as redacts,"
1568 " rej.event_id as rejects "
1569 " FROM event_json as e"
1570 " LEFT JOIN rejections as rej USING (event_id)"
1571 " LEFT JOIN redactions as r ON e.event_id = r.redacts"
1572 " WHERE e.event_id IN (%s)"
1573 ) % (",".join(["?"] * len(evs)),)
1574
1575 txn.execute(sql, evs)
1576 rows.extend(self.cursor_to_dict(txn))
1577
1578 return rows
1579
1580 @defer.inlineCallbacks
1581 def _get_event_from_row(self, internal_metadata, js, redacted,
1582 rejected_reason=None):
1583 with Measure(self._clock, "_get_event_from_row"):
1584 d = json.loads(js)
1585 internal_metadata = json.loads(internal_metadata)
1586
1587 if rejected_reason:
1588 rejected_reason = yield self._simple_select_one_onecol(
1589 table="rejections",
1590 keyvalues={"event_id": rejected_reason},
1591 retcol="reason",
1592 desc="_get_event_from_row_rejected_reason",
1593 )
1594
1595 original_ev = FrozenEvent(
1596 d,
1597 internal_metadata_dict=internal_metadata,
1598 rejected_reason=rejected_reason,
1599 )
1600
1601 redacted_event = None
1602 if redacted:
1603 redacted_event = prune_event(original_ev)
1604
1605 redaction_id = yield self._simple_select_one_onecol(
1606 table="redactions",
1607 keyvalues={"redacts": redacted_event.event_id},
1608 retcol="event_id",
1609 desc="_get_event_from_row_redactions",
1610 )
1611
1612 redacted_event.unsigned["redacted_by"] = redaction_id
1613 # Get the redaction event.
1614
1615 because = yield self.get_event(
1616 redaction_id,
1617 check_redacted=False,
1618 allow_none=True,
1619 )
1620
1621 if because:
1622 # It's fine to add the event directly, since get_pdu_json
1623 # will serialise this field correctly
1624 redacted_event.unsigned["redacted_because"] = because
1625
1626 cache_entry = _EventCacheEntry(
1627 event=original_ev,
1628 redacted_event=redacted_event,
1629 )
1630
1631 self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
1632
1633 defer.returnValue(cache_entry)
1634
1635 @defer.inlineCallbacks
16361322 def count_daily_messages(self):
16371323 """
16381324 Returns an estimate of the number of messages sent in the last day.
20171703 )
20181704 return self.runInteraction("get_all_new_events", get_all_new_events_txn)
20191705
2020 def delete_old_state(self, room_id, topological_ordering):
1706 def purge_history(
1707 self, room_id, topological_ordering, delete_local_events,
1708 ):
1709 """Deletes room history before a certain point
1710
1711 Args:
1712 room_id (str):
1713
1714 topological_ordering (int):
1715 minimum topo ordering to preserve
1716
1717 delete_local_events (bool):
1718 if True, we will delete local events as well as remote ones
1719 (instead of just marking them as outliers and deleting their
1720 state groups).
1721 """
1722
20211723 return self.runInteraction(
2022 "delete_old_state",
2023 self._delete_old_state_txn, room_id, topological_ordering
2024 )
2025
2026 def _delete_old_state_txn(self, txn, room_id, topological_ordering):
2027 """Deletes old room state
2028 """
2029
1724 "purge_history",
1725 self._purge_history_txn, room_id, topological_ordering,
1726 delete_local_events,
1727 )
1728
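A hypothetical call site, e.g. from the purge-history admin API: remove everything in the room before topological ordering 12345 while keeping locally-sent events (store and the ordering value are illustrative).

    # assuming `store` is the homeserver's EventsStore, inside an
    # inlineCallbacks function
    yield store.purge_history(room_id, 12345, delete_local_events=False)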
1729 def _purge_history_txn(
1730 self, txn, room_id, topological_ordering, delete_local_events,
1731 ):
20301732 # Tables that should be pruned:
20311733 # event_auth
20321734 # event_backward_extremities
20471749 # state_groups
20481750 # state_groups_state
20491751
1752 # we will build a temporary table listing the events so that we don't
1753 # have to keep shovelling the list back and forth across the
1754 # connection. Annoyingly the python sqlite driver commits the
1755 # transaction on CREATE, so let's do this first.
1756 #
1757 # furthermore, we might already have the table from a previous (failed)
1758 # purge attempt, so let's drop the table first.
1759
1760 txn.execute("DROP TABLE IF EXISTS events_to_purge")
1761
1762 txn.execute(
1763 "CREATE TEMPORARY TABLE events_to_purge ("
1764 " event_id TEXT NOT NULL,"
1765 " should_delete BOOLEAN NOT NULL"
1766 ")"
1767 )
1768
1769 # create an index on should_delete because later we'll be looking for
1770 # the should_delete / shouldn't_delete subsets
1771 txn.execute(
1772 "CREATE INDEX events_to_purge_should_delete"
1773 " ON events_to_purge(should_delete)",
1774 )
1775
20501776 # First ensure that we're not about to delete all the forward extremeties
20511777 txn.execute(
20521778 "SELECT e.event_id, e.depth FROM events as e "
20671793 400, "topological_ordering is greater than forward extremeties"
20681794 )
20691795
2070 logger.debug("[purge] looking for events to delete")
1796 logger.info("[purge] looking for events to delete")
1797
1798 should_delete_expr = "state_key IS NULL"
1799 should_delete_params = ()
1800 if not delete_local_events:
1801 should_delete_expr += " AND event_id NOT LIKE ?"
1802 should_delete_params += ("%:" + self.hs.hostname, )
1803
1804 should_delete_params += (room_id, topological_ordering)
20711805
20721806 txn.execute(
2073 "SELECT event_id, state_key FROM events"
2074 " LEFT JOIN state_events USING (room_id, event_id)"
2075 " WHERE room_id = ? AND topological_ordering < ?",
2076 (room_id, topological_ordering,)
1807 "INSERT INTO events_to_purge"
1808 " SELECT event_id, %s"
1809 " FROM events AS e LEFT JOIN state_events USING (event_id)"
1810 " WHERE e.room_id = ? AND topological_ordering < ?" % (
1811 should_delete_expr,
1812 ),
1813 should_delete_params,
1814 )
1815 txn.execute(
1816 "SELECT event_id, should_delete FROM events_to_purge"
20771817 )
20781818 event_rows = txn.fetchall()
2079
2080 to_delete = [
2081 (event_id,) for event_id, state_key in event_rows
2082 if state_key is None and not self.hs.is_mine_id(event_id)
2083 ]
20841819 logger.info(
2085 "[purge] found %i events before cutoff, of which %i are remote"
2086 " non-state events to delete", len(event_rows), len(to_delete))
2087
2088 for event_id, state_key in event_rows:
2089 txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
2090
2091 logger.debug("[purge] Finding new backward extremities")
1820 "[purge] found %i events before cutoff, of which %i can be deleted",
1821 len(event_rows), sum(1 for e in event_rows if e[1]),
1822 )
1823
1824 logger.info("[purge] Finding new backward extremities")
20921825
20931826 # We calculate the new entries for the backward extremeties by finding
20941827 # all events that point to events that are to be purged
20951828 txn.execute(
2096 "SELECT DISTINCT e.event_id FROM events as e"
2097 " INNER JOIN event_edges as ed ON e.event_id = ed.prev_event_id"
2098 " INNER JOIN events as e2 ON e2.event_id = ed.event_id"
2099 " WHERE e.room_id = ? AND e.topological_ordering < ?"
2100 " AND e2.topological_ordering >= ?",
2101 (room_id, topological_ordering, topological_ordering)
1829 "SELECT DISTINCT e.event_id FROM events_to_purge AS e"
1830 " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
1831 " INNER JOIN events AS e2 ON e2.event_id = ed.event_id"
1832 " WHERE e2.topological_ordering >= ?",
1833 (topological_ordering, )
21021834 )
21031835 new_backwards_extrems = txn.fetchall()
21041836
2105 logger.debug("[purge] replacing backward extremities: %r", new_backwards_extrems)
1837 logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
21061838
21071839 txn.execute(
21081840 "DELETE FROM event_backward_extremities WHERE room_id = ?",
21181850 ]
21191851 )
21201852
2121 logger.debug("[purge] finding redundant state groups")
1853 logger.info("[purge] finding redundant state groups")
21221854
21231855 # Get all state groups that are only referenced by events that are
21241856 # to be deleted.
21261858 "SELECT state_group FROM event_to_state_groups"
21271859 " INNER JOIN events USING (event_id)"
21281860 " WHERE state_group IN ("
2129 " SELECT DISTINCT state_group FROM events"
1861 " SELECT DISTINCT state_group FROM events_to_purge"
21301862 " INNER JOIN event_to_state_groups USING (event_id)"
2131 " WHERE room_id = ? AND topological_ordering < ?"
21321863 " )"
21331864 " GROUP BY state_group HAVING MAX(topological_ordering) < ?",
2134 (room_id, topological_ordering, topological_ordering)
1865 (topological_ordering, )
21351866 )
21361867
21371868 state_rows = txn.fetchall()
2138 logger.debug("[purge] found %i redundant state groups", len(state_rows))
1869 logger.info("[purge] found %i redundant state groups", len(state_rows))
21391870
21401871 # make a set of the redundant state groups, so that we can look them up
21411872 # efficiently
21421873 state_groups_to_delete = set([sg for sg, in state_rows])
21431874
21441875 # Now we get all the state groups that rely on these state groups
2145 logger.debug("[purge] finding state groups which depend on redundant"
2146 " state groups")
1876 logger.info("[purge] finding state groups which depend on redundant"
1877 " state groups")
21471878 remaining_state_groups = []
21481879 for i in xrange(0, len(state_rows), 100):
21491880 chunk = [sg for sg, in state_rows[i:i + 100]]
21681899 # Now we turn the state groups that reference to-be-deleted state
21691900 # groups to non delta versions.
21701901 for sg in remaining_state_groups:
2171 logger.debug("[purge] de-delta-ing remaining state group %s", sg)
1902 logger.info("[purge] de-delta-ing remaining state group %s", sg)
21721903 curr_state = self._get_state_groups_from_groups_txn(
21731904 txn, [sg], types=None
21741905 )
22051936 ],
22061937 )
22071938
2208 logger.debug("[purge] removing redundant state groups")
1939 logger.info("[purge] removing redundant state groups")
22091940 txn.executemany(
22101941 "DELETE FROM state_groups_state WHERE state_group = ?",
22111942 state_rows
22151946 state_rows
22161947 )
22171948
2218 # Delete all non-state
2219 logger.debug("[purge] removing events from event_to_state_groups")
2220 txn.executemany(
2221 "DELETE FROM event_to_state_groups WHERE event_id = ?",
2222 [(event_id,) for event_id, _ in event_rows]
2223 )
2224
2225 logger.debug("[purge] updating room_depth")
1949 logger.info("[purge] removing events from event_to_state_groups")
22261950 txn.execute(
2227 "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
2228 (topological_ordering, room_id,)
2229 )
1951 "DELETE FROM event_to_state_groups "
1952 "WHERE event_id IN (SELECT event_id from events_to_purge)"
1953 )
1954 for event_id, _ in event_rows:
1955 txn.call_after(self._get_state_group_for_event.invalidate, (
1956 event_id,
1957 ))
22301958
22311959 # Delete all remote non-state events
22321960 for table in (
22381966 "event_edge_hashes",
22391967 "event_edges",
22401968 "event_forward_extremities",
2241 "event_push_actions",
22421969 "event_reference_hashes",
22431970 "event_search",
22441971 "event_signatures",
22451972 "rejections",
22461973 ):
2247 logger.debug("[purge] removing remote non-state events from %s", table)
2248
2249 txn.executemany(
2250 "DELETE FROM %s WHERE event_id = ?" % (table,),
2251 to_delete
1974 logger.info("[purge] removing events from %s", table)
1975
1976 txn.execute(
1977 "DELETE FROM %s WHERE event_id IN ("
1978 " SELECT event_id FROM events_to_purge WHERE should_delete"
1979 ")" % (table,),
1980 )
1981
1982 # event_push_actions lacks an index on event_id, and has one on
1983 # (room_id, event_id) instead.
1984 for table in (
1985 "event_push_actions",
1986 ):
1987 logger.info("[purge] removing events from %s", table)
1988
1989 txn.execute(
1990 "DELETE FROM %s WHERE room_id = ? AND event_id IN ("
1991 " SELECT event_id FROM events_to_purge WHERE should_delete"
1992 ")" % (table,),
1993 (room_id, )
22521994 )
22531995
22541996 # Mark all state and own events as outliers
2255 logger.debug("[purge] marking remaining events as outliers")
2256 txn.executemany(
1997 logger.info("[purge] marking remaining events as outliers")
1998 txn.execute(
22571999 "UPDATE events SET outlier = ?"
2258 " WHERE event_id = ?",
2259 [
2260 (True, event_id,) for event_id, state_key in event_rows
2261 if state_key is not None or self.hs.is_mine_id(event_id)
2262 ]
2000 " WHERE event_id IN ("
2001 " SELECT event_id FROM events_to_purge "
2002 " WHERE NOT should_delete"
2003 ")",
2004 (True,),
2005 )
2006
2007 # synapse tries to take out an exclusive lock on room_depth whenever it
2008 # persists events (because upsert), and once we run this update, we
2009 # will block that for the rest of our transaction.
2010 #
2011 # So, let's stick it at the end so that we don't block event
2012 # persistence.
2013 logger.info("[purge] updating room_depth")
2014 txn.execute(
2015 "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
2016 (topological_ordering, room_id,)
2017 )
2018
2019 # finally, drop the temp table. this will commit the txn in sqlite,
2020 # so make sure to keep this actually last.
2021 txn.execute(
2022 "DROP TABLE events_to_purge"
22632023 )
22642024
22652025 logger.info("[purge] done")
22722032 to_2, so_2 = yield self._get_event_ordering(event_id2)
22732033 defer.returnValue((to_1, so_1) > (to_2, so_2))
22742034
2275 @defer.inlineCallbacks
2035 @cachedInlineCallbacks(max_entries=5000)
22762036 def _get_event_ordering(self, event_id):
22772037 res = yield self._simple_select_one(
22782038 table="events",
0 # -*- coding: utf-8 -*-
1 # Copyright 2018 New Vector Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 from ._base import SQLBaseStore
15
16 from twisted.internet import defer, reactor
17
18 from synapse.events import FrozenEvent
19 from synapse.events.utils import prune_event
20
21 from synapse.util.logcontext import (
22 preserve_fn, PreserveLoggingContext, make_deferred_yieldable
23 )
24 from synapse.util.metrics import Measure
25 from synapse.api.errors import SynapseError
26
27 from collections import namedtuple
28
29 import logging
30 import simplejson as json
31
32 # these are only included to make the type annotations work
33 from synapse.events import EventBase # noqa: F401
34 from synapse.events.snapshot import EventContext # noqa: F401
35
36 logger = logging.getLogger(__name__)
37
38
39 # These values are used in the `_enqueue_events` and `_do_fetch` methods to
40 # control how we batch/bulk fetch events from the database.
41 # The values are plucked out of thin air to make initial sync run faster
42 # on jki.re
43 # TODO: Make these configurable.
44 EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
45 EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
46 EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
47
48
49 _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
50
51
52 class EventsWorkerStore(SQLBaseStore):
53
54 @defer.inlineCallbacks
55 def get_event(self, event_id, check_redacted=True,
56 get_prev_content=False, allow_rejected=False,
57 allow_none=False):
58 """Get an event from the database by event_id.
59
60 Args:
61 event_id (str): The event_id of the event to fetch
62 check_redacted (bool): If True, check if event has been redacted
63 and redact it.
64 get_prev_content (bool): If True and event is a state event,
65 include the previous states content in the unsigned field.
66 allow_rejected (bool): If True return rejected events.
67 allow_none (bool): If True, return None if no event found, if
68 False throw an exception.
69
70 Returns:
71 Deferred : A FrozenEvent.
72 """
73 events = yield self._get_events(
74 [event_id],
75 check_redacted=check_redacted,
76 get_prev_content=get_prev_content,
77 allow_rejected=allow_rejected,
78 )
79
80 if not events and not allow_none:
81 raise SynapseError(404, "Could not find event %s" % (event_id,))
82
83 defer.returnValue(events[0] if events else None)
84
85 @defer.inlineCallbacks
86 def get_events(self, event_ids, check_redacted=True,
87 get_prev_content=False, allow_rejected=False):
88 """Get events from the database
89
90 Args:
91 event_ids (list): The event_ids of the events to fetch
92 check_redacted (bool): If True, check if event has been redacted
93 and redact it.
94 get_prev_content (bool): If True and event is a state event,
95 include the previous states content in the unsigned field.
96 allow_rejected (bool): If True return rejected events.
97
98 Returns:
99 Deferred : Dict from event_id to event.
100 """
101 events = yield self._get_events(
102 event_ids,
103 check_redacted=check_redacted,
104 get_prev_content=get_prev_content,
105 allow_rejected=allow_rejected,
106 )
107
108 defer.returnValue({e.event_id: e for e in events})
109
110 @defer.inlineCallbacks
111 def _get_events(self, event_ids, check_redacted=True,
112 get_prev_content=False, allow_rejected=False):
113 if not event_ids:
114 defer.returnValue([])
115
116 event_id_list = event_ids
117 event_ids = set(event_ids)
118
119 event_entry_map = self._get_events_from_cache(
120 event_ids,
121 allow_rejected=allow_rejected,
122 )
123
124 missing_events_ids = [e for e in event_ids if e not in event_entry_map]
125
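# any ids not satisfied from the cache are fetched from the database via
# the shared event-fetch queue (see _enqueue_events below)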
126 if missing_events_ids:
127 missing_events = yield self._enqueue_events(
128 missing_events_ids,
129 check_redacted=check_redacted,
130 allow_rejected=allow_rejected,
131 )
132
133 event_entry_map.update(missing_events)
134
135 events = []
136 for event_id in event_id_list:
137 entry = event_entry_map.get(event_id, None)
138 if not entry:
139 continue
140
141 if allow_rejected or not entry.event.rejected_reason:
142 if check_redacted and entry.redacted_event:
143 event = entry.redacted_event
144 else:
145 event = entry.event
146
147 events.append(event)
148
149 if get_prev_content:
150 if "replaces_state" in event.unsigned:
151 prev = yield self.get_event(
152 event.unsigned["replaces_state"],
153 get_prev_content=False,
154 allow_none=True,
155 )
156 if prev:
157 event.unsigned = dict(event.unsigned)
158 event.unsigned["prev_content"] = prev.content
159 event.unsigned["prev_sender"] = prev.sender
160
161 defer.returnValue(events)
162
163 def _invalidate_get_event_cache(self, event_id):
164 self._get_event_cache.invalidate((event_id,))
165
166 def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
167 """Fetch events from the caches
168
169 Args:
170 events (list(str)): list of event_ids to fetch
171 allow_rejected (bool): Whether to return events that were rejected
172 update_metrics (bool): Whether to update the cache hit ratio metrics
173
174 Returns:
175 dict of event_id -> _EventCacheEntry for each event_id in cache. If
176 allow_rejected is `False`, a rejected event will still have an entry,
177 but it will be `None`
178 """
179 event_map = {}
180
181 for event_id in events:
182 ret = self._get_event_cache.get(
183 (event_id,), None,
184 update_metrics=update_metrics,
185 )
186 if not ret:
187 continue
188
189 if allow_rejected or not ret.event.rejected_reason:
190 event_map[event_id] = ret
191 else:
192 event_map[event_id] = None
193
194 return event_map
195
196 def _do_fetch(self, conn):
197 """Takes a database connection and waits for requests for events from
198 the _event_fetch_list queue.
199 """
200 event_list = []
201 i = 0
202 while True:
203 try:
204 with self._event_fetch_lock:
205 event_list = self._event_fetch_list
206 self._event_fetch_list = []
207
208 if not event_list:
209 single_threaded = self.database_engine.single_threaded
210 if single_threaded or i > EVENT_QUEUE_ITERATIONS:
211 self._event_fetch_ongoing -= 1
212 return
213 else:
214 self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
215 i += 1
216 continue
217 i = 0
218
219 event_id_lists = zip(*event_list)[0]
220 event_ids = [
221 item for sublist in event_id_lists for item in sublist
222 ]
223
224 rows = self._new_transaction(
225 conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
226 )
227
228 row_dict = {
229 r["event_id"]: r
230 for r in rows
231 }
232
233 # We only want to resolve deferreds from the main thread
234 def fire(lst, res):
235 for ids, d in lst:
236 if not d.called:
237 try:
238 with PreserveLoggingContext():
239 d.callback([
240 res[i]
241 for i in ids
242 if i in res
243 ])
244 except Exception:
245 logger.exception("Failed to callback")
246 with PreserveLoggingContext():
247 reactor.callFromThread(fire, event_list, row_dict)
248 except Exception as e:
249 logger.exception("do_fetch")
250
251 # We only want to resolve deferreds from the main thread
252 def fire(evs):
253 for _, d in evs:
254 if not d.called:
255 with PreserveLoggingContext():
256 d.errback(e)
257
258 if event_list:
259 with PreserveLoggingContext():
260 reactor.callFromThread(fire, event_list)
261
262 @defer.inlineCallbacks
263 def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
264 """Fetches events from the database using the _event_fetch_list. This
265 allows batch and bulk fetching of events - it allows us to fetch events
266 without having to create a new transaction for each request for events.
267 """
268 if not events:
269 defer.returnValue({})
270
271 events_d = defer.Deferred()
272 with self._event_fetch_lock:
273 self._event_fetch_list.append(
274 (events, events_d)
275 )
276
277 self._event_fetch_lock.notify()
278
279 if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
280 self._event_fetch_ongoing += 1
281 should_start = True
282 else:
283 should_start = False
284
285 if should_start:
286 with PreserveLoggingContext():
287 self.runWithConnection(
288 self._do_fetch
289 )
290
291 logger.debug("Loading %d events", len(events))
292 with PreserveLoggingContext():
293 rows = yield events_d
294 logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
295
296 if not allow_rejected:
297 rows[:] = [r for r in rows if not r["rejects"]]
298
299 res = yield make_deferred_yieldable(defer.gatherResults(
300 [
301 preserve_fn(self._get_event_from_row)(
302 row["internal_metadata"], row["json"], row["redacts"],
303 rejected_reason=row["rejects"],
304 )
305 for row in rows
306 ],
307 consumeErrors=True
308 ))
309
310 defer.returnValue({
311 e.event.event_id: e
312 for e in res if e
313 })
314
315 def _fetch_event_rows(self, txn, events):
316 rows = []
317 N = 200
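# fetch the rows in chunks of N event ids so the IN (...) clause below
# stays a manageable size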
318 for i in range(1 + len(events) / N):
319 evs = events[i * N:(i + 1) * N]
320 if not evs:
321 break
322
323 sql = (
324 "SELECT "
325 " e.event_id as event_id, "
326 " e.internal_metadata,"
327 " e.json,"
328 " r.redacts as redacts,"
329 " rej.event_id as rejects "
330 " FROM event_json as e"
331 " LEFT JOIN rejections as rej USING (event_id)"
332 " LEFT JOIN redactions as r ON e.event_id = r.redacts"
333 " WHERE e.event_id IN (%s)"
334 ) % (",".join(["?"] * len(evs)),)
335
336 txn.execute(sql, evs)
337 rows.extend(self.cursor_to_dict(txn))
338
339 return rows
340
341 @defer.inlineCallbacks
342 def _get_event_from_row(self, internal_metadata, js, redacted,
343 rejected_reason=None):
344 with Measure(self._clock, "_get_event_from_row"):
345 d = json.loads(js)
346 internal_metadata = json.loads(internal_metadata)
347
348 if rejected_reason:
349 rejected_reason = yield self._simple_select_one_onecol(
350 table="rejections",
351 keyvalues={"event_id": rejected_reason},
352 retcol="reason",
353 desc="_get_event_from_row_rejected_reason",
354 )
355
356 original_ev = FrozenEvent(
357 d,
358 internal_metadata_dict=internal_metadata,
359 rejected_reason=rejected_reason,
360 )
361
362 redacted_event = None
363 if redacted:
364 redacted_event = prune_event(original_ev)
365
366 redaction_id = yield self._simple_select_one_onecol(
367 table="redactions",
368 keyvalues={"redacts": redacted_event.event_id},
369 retcol="event_id",
370 desc="_get_event_from_row_redactions",
371 )
372
373 redacted_event.unsigned["redacted_by"] = redaction_id
374 # Get the redaction event.
375
376 because = yield self.get_event(
377 redaction_id,
378 check_redacted=False,
379 allow_none=True,
380 )
381
382 if because:
383 # It's fine to add the event directly, since get_pdu_json
384 # will serialise this field correctly
385 redacted_event.unsigned["redacted_because"] = because
386
387 cache_entry = _EventCacheEntry(
388 event=original_ev,
389 redacted_event=redacted_event,
390 )
391
392 self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
393
394 defer.returnValue(cache_entry)
3434 keyvalues={
3535 "group_id": group_id,
3636 },
37 retcols=("name", "short_description", "long_description", "avatar_url",),
37 retcols=(
38 "name", "short_description", "long_description", "avatar_url", "is_public"
39 ),
3840 allow_none=True,
3941 desc="is_user_in_group",
4042 )
5153 return self._simple_select_list(
5254 table="group_users",
5355 keyvalues=keyvalues,
54 retcols=("user_id", "is_public",),
56 retcols=("user_id", "is_public", "is_admin",),
5557 desc="get_users_in_group",
5658 )
5759
852854 "is_public": is_public,
853855 },
854856 desc="add_room_to_group",
857 )
858
859 def update_room_in_group_visibility(self, group_id, room_id, is_public):
860 return self._simple_update(
861 table="group_rooms",
862 keyvalues={
863 "group_id": group_id,
864 "room_id": room_id,
865 },
866 updatevalues={
867 "is_public": is_public,
868 },
869 desc="update_room_in_group_visibility",
855870 )
856871
857872 def remove_room_from_group(self, group_id, room_id):
10251040 "avatar_url": avatar_url,
10261041 "short_description": short_description,
10271042 "long_description": long_description,
1043 "is_public": True,
10281044 },
10291045 desc="create_group",
10301046 )
10831099 "attestation_json": json.dumps(attestation)
10841100 },
10851101 desc="update_remote_attestion",
1102 )
1103
1104 def remove_attestation_renewal(self, group_id, user_id):
1105 """Remove an attestation that we thought we should renew, but actually
1106 shouldn't. Ideally this would never get called as we would never
1107 incorrectly try to do attestations for local users on local groups.
1108
1109 Args:
1110 group_id (str)
1111 user_id (str)
1112 """
1113 return self._simple_delete(
1114 table="group_attestations_renewals",
1115 keyvalues={
1116 "group_id": group_id,
1117 "user_id": user_id,
1118 },
1119 desc="remove_attestation_renewal",
10861120 )
10871121
10881122 @defer.inlineCallbacks
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14
15 from ._base import SQLBaseStore
16
17
18 class MediaRepositoryStore(SQLBaseStore):
14 from synapse.storage.background_updates import BackgroundUpdateStore
15
16
17 class MediaRepositoryStore(BackgroundUpdateStore):
1918 """Persistence for attachments and avatars"""
2019
21 def get_default_thumbnails(self, top_level_type, sub_type):
22 return []
20 def __init__(self, db_conn, hs):
21 super(MediaRepositoryStore, self).__init__(db_conn, hs)
22
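# register a background update to build a partial index over url-cache
# entries by creation time (assumed to support expiring old URL preview
# media)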
23 self.register_background_index_update(
24 update_name='local_media_repository_url_idx',
25 index_name='local_media_repository_url_idx',
26 table='local_media_repository',
27 columns=['created_ts'],
28 where_clause='url_cache IS NOT NULL',
29 )
2330
2431 def get_local_media(self, media_id):
2532 """Get the metadata for a local piece of media
165172 desc="store_cached_remote_media",
166173 )
167174
168 def update_cached_last_access_time(self, origin_id_tuples, time_ts):
175 def update_cached_last_access_time(self, local_media, remote_media, time_ms):
176 """Updates the last access time of the given media
177
178 Args:
179 local_media (iterable[str]): Set of media_ids
180 remote_media (iterable[(str, str)]): Set of (server_name, media_id)
181 time_ms: Current time in milliseconds
182 """
169183 def update_cache_txn(txn):
170184 sql = (
171185 "UPDATE remote_media_cache SET last_access_ts = ?"
173187 )
174188
175189 txn.executemany(sql, (
176 (time_ts, media_origin, media_id)
177 for media_origin, media_id in origin_id_tuples
190 (time_ms, media_origin, media_id)
191 for media_origin, media_id in remote_media
192 ))
193
194 sql = (
195 "UPDATE local_media_repository SET last_access_ts = ?"
196 " WHERE media_id = ?"
197 )
198
199 txn.executemany(sql, (
200 (time_ms, media_id)
201 for media_id in local_media
178202 ))
179203
180204 return self.runInteraction("update_cached_last_access_time", update_cache_txn)
253277 return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
254278
255279 def delete_url_cache(self, media_ids):
280 if len(media_ids) == 0:
281 return
282
256283 sql = (
257284 "DELETE FROM local_media_repository_url_cache"
258285 " WHERE media_id = ?"
280307 )
281308
282309 def delete_url_cache_media(self, media_ids):
310 if len(media_ids) == 0:
311 return
312
283313 def _delete_url_cache_media_txn(txn):
284314 sql = (
285315 "DELETE FROM local_media_repository"
2424
2525 # Remember to update this number every time a change is made to database
2626 # schema files, so the users will be informed on server restarts.
27 SCHEMA_VERSION = 45
27 SCHEMA_VERSION = 47
2828
2929 dir_path = os.path.abspath(os.path.dirname(__file__))
3030
4343
4444 If `config` is None then prepare_database will assert that no upgrade is
4545 necessary, *or* will create a fresh database if the database is empty.
46
47 Args:
48 db_conn:
49 database_engine:
50 config (synapse.config.homeserver.HomeServerConfig|None):
51 application config, or None if we are connecting to an existing
52 database which we expect to be configured already
4653 """
4754 try:
4855 cur = db_conn.cursor()
6370 else:
6471 _setup_new_database(cur, database_engine)
6572
73 # check if any of our configured dynamic modules want a database
74 if config is not None:
75 _apply_module_schemas(cur, database_engine, config)
76
6677 cur.close()
6778 db_conn.commit()
68 except:
79 except Exception:
6980 db_conn.rollback()
7081 raise
7182
282293 )
283294
284295
296 def _apply_module_schemas(txn, database_engine, config):
297 """Apply the module schemas for the dynamic modules, if any
298
299 Args:
300 txn: database cursor
301 database_engine: synapse database engine class
302 config (synapse.config.homeserver.HomeServerConfig):
303 application config
304 """
305 for (mod, _config) in config.password_providers:
306 if not hasattr(mod, 'get_db_schema_files'):
307 continue
308 modname = ".".join((mod.__module__, mod.__name__))
309 _apply_module_schema_files(
310 txn, database_engine, modname, mod.get_db_schema_files(),
311 )
312
313
314 def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
315 """Apply the module schemas for a single module
316
317 Args:
318 cur: database cursor
319 database_engine: synapse database engine class
320 modname (str): fully qualified name of the module
321 names_and_streams (Iterable[(str, file)]): the names and streams of
322 schemas to be applied
323 """
324 cur.execute(
325 database_engine.convert_param_style(
326 "SELECT file FROM applied_module_schemas WHERE module_name = ?"
327 ),
328 (modname,)
329 )
330 applied_deltas = set(d for d, in cur)
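# apply each schema file we haven't already recorded as applied for
# this module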
331 for (name, stream) in names_and_streams:
332 if name in applied_deltas:
333 continue
334
335 root_name, ext = os.path.splitext(name)
336 if ext != '.sql':
337 raise PrepareDatabaseException(
338 "only .sql files are currently supported for module schemas",
339 )
340
341 logger.info("applying schema %s for %s", name, modname)
342 for statement in get_statements(stream):
343 cur.execute(statement)
344
345 # Mark as done.
346 cur.execute(
347 database_engine.convert_param_style(
348 "INSERT INTO applied_module_schemas (module_name, file)"
349 " VALUES (?,?)",
350 ),
351 (modname, name)
352 )
353
354
285355 def get_statements(f):
286356 statement_buffer = ""
287357 in_comment = False # If we're in a /* ... */ style comment
1414
1515 from twisted.internet import defer
1616
17 from synapse.storage.roommember import ProfileInfo
18 from synapse.api.errors import StoreError
19
1720 from ._base import SQLBaseStore
1821
1922
20 class ProfileStore(SQLBaseStore):
21 def create_profile(self, user_localpart):
22 return self._simple_insert(
23 table="profiles",
24 values={"user_id": user_localpart},
25 desc="create_profile",
23 class ProfileWorkerStore(SQLBaseStore):
24 @defer.inlineCallbacks
25 def get_profileinfo(self, user_localpart):
26 try:
27 profile = yield self._simple_select_one(
28 table="profiles",
29 keyvalues={"user_id": user_localpart},
30 retcols=("displayname", "avatar_url"),
31 desc="get_profileinfo",
32 )
33 except StoreError as e:
34 if e.code == 404:
35 # no match
36 defer.returnValue(ProfileInfo(None, None))
37 return
38 else:
39 raise
40
41 defer.returnValue(
42 ProfileInfo(
43 avatar_url=profile['avatar_url'],
44 display_name=profile['displayname'],
45 )
2646 )
2747
2848 def get_profile_displayname(self, user_localpart):
3353 desc="get_profile_displayname",
3454 )
3555
36 def set_profile_displayname(self, user_localpart, new_displayname):
37 return self._simple_update_one(
38 table="profiles",
39 keyvalues={"user_id": user_localpart},
40 updatevalues={"displayname": new_displayname},
41 desc="set_profile_displayname",
42 )
43
4456 def get_profile_avatar_url(self, user_localpart):
4557 return self._simple_select_one_onecol(
4658 table="profiles",
4759 keyvalues={"user_id": user_localpart},
4860 retcol="avatar_url",
4961 desc="get_profile_avatar_url",
50 )
51
52 def set_profile_avatar_url(self, user_localpart, new_avatar_url):
53 return self._simple_update_one(
54 table="profiles",
55 keyvalues={"user_id": user_localpart},
56 updatevalues={"avatar_url": new_avatar_url},
57 desc="set_profile_avatar_url",
5862 )
5963
6064 def get_from_remote_profile_cache(self, user_id):
6468 retcols=("displayname", "avatar_url",),
6569 allow_none=True,
6670 desc="get_from_remote_profile_cache",
71 )
72
73
74 class ProfileStore(ProfileWorkerStore):
75 def create_profile(self, user_localpart):
76 return self._simple_insert(
77 table="profiles",
78 values={"user_id": user_localpart},
79 desc="create_profile",
80 )
81
82 def set_profile_displayname(self, user_localpart, new_displayname):
83 return self._simple_update_one(
84 table="profiles",
85 keyvalues={"user_id": user_localpart},
86 updatevalues={"displayname": new_displayname},
87 desc="set_profile_displayname",
88 )
89
90 def set_profile_avatar_url(self, user_localpart, new_avatar_url):
91 return self._simple_update_one(
92 table="profiles",
93 keyvalues={"user_id": user_localpart},
94 updatevalues={"avatar_url": new_avatar_url},
95 desc="set_profile_avatar_url",
6796 )
6897
6998 def add_remote_profile_cache(self, user_id, displayname, avatar_url):
00 # -*- coding: utf-8 -*-
11 # Copyright 2014-2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1314 # limitations under the License.
1415
1516 from ._base import SQLBaseStore
17 from synapse.storage.appservice import ApplicationServiceWorkerStore
18 from synapse.storage.pusher import PusherWorkerStore
19 from synapse.storage.receipts import ReceiptsWorkerStore
20 from synapse.storage.roommember import RoomMemberWorkerStore
1621 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
22 from synapse.util.caches.stream_change_cache import StreamChangeCache
1723 from synapse.push.baserules import list_with_base_rules
1824 from synapse.api.constants import EventTypes
1925 from twisted.internet import defer
2026
27 import abc
2128 import logging
2229 import simplejson as json
2330
4754 return rules
4855
4956
50 class PushRuleStore(SQLBaseStore):
57 class PushRulesWorkerStore(ApplicationServiceWorkerStore,
58 ReceiptsWorkerStore,
59 PusherWorkerStore,
60 RoomMemberWorkerStore,
61 SQLBaseStore):
62 """This is an abstract base class where subclasses must implement
63 `get_max_push_rules_stream_id` which can be called in the initializer.
64 """
65
66 # This ABCMeta metaclass ensures that we cannot be instantiated without
67 # the abstract methods being implemented.
68 __metaclass__ = abc.ABCMeta
69
70 def __init__(self, db_conn, hs):
71 super(PushRulesWorkerStore, self).__init__(db_conn, hs)
72
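# prefill the stream change cache from push_rules_stream so we can
# cheaply answer "have this user's rules changed since stream id X?"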
73 push_rules_prefill, push_rules_id = self._get_cache_dict(
74 db_conn, "push_rules_stream",
75 entity_column="user_id",
76 stream_column="stream_id",
77 max_value=self.get_max_push_rules_stream_id(),
78 )
79
80 self.push_rules_stream_cache = StreamChangeCache(
81 "PushRulesStreamChangeCache", push_rules_id,
82 prefilled_cache=push_rules_prefill,
83 )
84
85 @abc.abstractmethod
86 def get_max_push_rules_stream_id(self):
87 """Get the position of the push rules stream.
88
89 Returns:
90 int
91 """
92 raise NotImplementedError()
93
5194 @cachedInlineCallbacks(max_entries=5000)
5295 def get_push_rules_for_user(self, user_id):
5396 rows = yield self._simple_select_list(
88131 r['rule_id']: False if r['enabled'] == 0 else True for r in results
89132 })
90133
134 def have_push_rules_changed_for_user(self, user_id, last_id):
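# check the in-memory stream change cache first; only hit the database
# if the user may have changed rules since last_id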
135 if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
136 return defer.succeed(False)
137 else:
138 def have_push_rules_changed_txn(txn):
139 sql = (
140 "SELECT COUNT(stream_id) FROM push_rules_stream"
141 " WHERE user_id = ? AND ? < stream_id"
142 )
143 txn.execute(sql, (user_id, last_id))
144 count, = txn.fetchone()
145 return bool(count)
146 return self.runInteraction(
147 "have_push_rules_changed", have_push_rules_changed_txn
148 )
149
91150 @cachedList(cached_method_name="get_push_rules_for_user",
92151 list_name="user_ids", num_args=1, inlineCallbacks=True)
93152 def bulk_get_push_rules(self, user_ids):
227286 results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
228287 defer.returnValue(results)
229288
289
290 class PushRuleStore(PushRulesWorkerStore):
230291 @defer.inlineCallbacks
231292 def add_push_rule(
232293 self, user_id, rule_id, priority_class, conditions, actions,
525586 room stream ordering it corresponds to."""
526587 return self._push_rules_stream_id_gen.get_current_token()
527588
528 def have_push_rules_changed_for_user(self, user_id, last_id):
529 if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
530 return defer.succeed(False)
531 else:
532 def have_push_rules_changed_txn(txn):
533 sql = (
534 "SELECT COUNT(stream_id) FROM push_rules_stream"
535 " WHERE user_id = ? AND ? < stream_id"
536 )
537 txn.execute(sql, (user_id, last_id))
538 count, = txn.fetchone()
539 return bool(count)
540 return self.runInteraction(
541 "have_push_rules_changed", have_push_rules_changed_txn
542 )
589 def get_max_push_rules_stream_id(self):
590 return self.get_push_rules_stream_token()[0]
543591
544592
545593 class RuleNotFoundException(Exception):
00 # -*- coding: utf-8 -*-
11 # Copyright 2014-2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
2627 logger = logging.getLogger(__name__)
2728
2829
29 class PusherStore(SQLBaseStore):
30 class PusherWorkerStore(SQLBaseStore):
3031 def _decode_pushers_rows(self, rows):
3132 for r in rows:
3233 dataJson = r['data']
101102 rows = yield self.runInteraction("get_all_pushers", get_pushers)
102103 defer.returnValue(rows)
103104
104 def get_pushers_stream_token(self):
105 return self._pushers_id_gen.get_current_token()
106
107105 def get_all_updated_pushers(self, last_id, current_id, limit):
108106 if last_id == current_id:
109107 return defer.succeed(([], []))
197195
198196 defer.returnValue(result)
199197
198
199 class PusherStore(PusherWorkerStore):
200 def get_pushers_stream_token(self):
201 return self._pushers_id_gen.get_current_token()
202
200203 @defer.inlineCallbacks
201204 def add_pusher(self, user_id, access_token, kind, app_id,
202205 app_display_name, device_display_name,
203206 pushkey, pushkey_ts, lang, data, last_stream_ordering,
204207 profile_tag=""):
205208 with self._pushers_id_gen.get_next() as stream_id:
206 def f(txn):
207 newly_inserted = self._simple_upsert_txn(
208 txn,
209 "pushers",
210 {
211 "app_id": app_id,
212 "pushkey": pushkey,
213 "user_name": user_id,
214 },
215 {
216 "access_token": access_token,
217 "kind": kind,
218 "app_display_name": app_display_name,
219 "device_display_name": device_display_name,
220 "ts": pushkey_ts,
221 "lang": lang,
222 "data": encode_canonical_json(data),
223 "last_stream_ordering": last_stream_ordering,
224 "profile_tag": profile_tag,
225 "id": stream_id,
226 },
209 # no need to lock because `pushers` has a unique key on
210 # (app_id, pushkey, user_name) so _simple_upsert will retry
211 newly_inserted = yield self._simple_upsert(
212 table="pushers",
213 keyvalues={
214 "app_id": app_id,
215 "pushkey": pushkey,
216 "user_name": user_id,
217 },
218 values={
219 "access_token": access_token,
220 "kind": kind,
221 "app_display_name": app_display_name,
222 "device_display_name": device_display_name,
223 "ts": pushkey_ts,
224 "lang": lang,
225 "data": encode_canonical_json(data),
226 "last_stream_ordering": last_stream_ordering,
227 "profile_tag": profile_tag,
228 "id": stream_id,
229 },
230 desc="add_pusher",
231 lock=False,
232 )
233
234 if newly_inserted:
235 self.runInteraction(
236 "add_pusher",
237 self._invalidate_cache_and_stream,
238 self.get_if_user_has_pusher, (user_id,)
227239 )
228 if newly_inserted:
229 # get_if_user_has_pusher only cares if the user has
230 # at least *one* pusher.
231 txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
232
233 yield self.runInteraction("add_pusher", f)
234240
235241 @defer.inlineCallbacks
236242 def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
237243 def delete_pusher_txn(txn, stream_id):
238 txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
244 self._invalidate_cache_and_stream(
245 txn, self.get_if_user_has_pusher, (user_id,)
246 )
239247
240248 self._simple_delete_one_txn(
241249 txn,
242250 "pushers",
243251 {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}
244252 )
245 self._simple_upsert_txn(
253
254 # it's possible for us to end up with duplicate rows for
255 # (app_id, pushkey, user_id) at different stream_ids, but that
256 # doesn't really matter.
257 self._simple_insert_txn(
246258 txn,
247 "deleted_pushers",
248 {"app_id": app_id, "pushkey": pushkey, "user_id": user_id},
249 {"stream_id": stream_id},
259 table="deleted_pushers",
260 values={
261 "stream_id": stream_id,
262 "app_id": app_id,
263 "pushkey": pushkey,
264 "user_id": user_id,
265 },
250266 )
251267
252268 with self._pushers_id_gen.get_next() as stream_id:
309325
310326 @defer.inlineCallbacks
311327 def set_throttle_params(self, pusher_id, room_id, params):
328 # no need to lock because `pusher_throttle` has a primary key on
329 # (pusher, room_id) so _simple_upsert will retry
312330 yield self._simple_upsert(
313331 "pusher_throttle",
314332 {"pusher": pusher_id, "room_id": room_id},
315333 params,
316 desc="set_throttle_params"
317 )
334 desc="set_throttle_params",
335 lock=False,
336 )
00 # -*- coding: utf-8 -*-
11 # Copyright 2014-2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1314 # limitations under the License.
1415
1516 from ._base import SQLBaseStore
17 from .util.id_generators import StreamIdGenerator
1618 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
1719 from synapse.util.caches.stream_change_cache import StreamChangeCache
1820
1921 from twisted.internet import defer
2022
23 import abc
2124 import logging
22 import ujson as json
25 import simplejson as json
2326
2427
2528 logger = logging.getLogger(__name__)
2629
2730
28 class ReceiptsStore(SQLBaseStore):
29 def __init__(self, hs):
30 super(ReceiptsStore, self).__init__(hs)
31 class ReceiptsWorkerStore(SQLBaseStore):
32 """This is an abstract base class where subclasses must implement
33 `get_max_receipt_stream_id` which can be called in the initializer.
34 """
35
36 # This ABCMeta metaclass ensures that we cannot be instantiated without
37 # the abstract methods being implemented.
38 __metaclass__ = abc.ABCMeta
39
40 def __init__(self, db_conn, hs):
41 super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
3142
3243 self._receipts_stream_cache = StreamChangeCache(
33 "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
34 )
44 "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
45 )
46
47 @abc.abstractmethod
48 def get_max_receipt_stream_id(self):
49 """Get the current max stream ID for receipts stream
50
51 Returns:
52 int
53 """
54 raise NotImplementedError()
3555
3656 @cachedInlineCallbacks()
3757 def get_users_with_read_receipts_in_room(self, room_id):
3858 receipts = yield self.get_receipts_for_room(room_id, "m.read")
3959 defer.returnValue(set(r['user_id'] for r in receipts))
40
41 def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
42 user_id):
43 if receipt_type != "m.read":
44 return
45
46 # Returns an ObservableDeferred
47 res = self.get_users_with_read_receipts_in_room.cache.get(
48 room_id, None, update_metrics=False,
49 )
50
51 if res:
52 if isinstance(res, defer.Deferred) and res.called:
53 res = res.result
54 if user_id in res:
55 # We'd only be adding to the set, so no point invalidating if the
56 # user is already there
57 return
58
59 self.get_users_with_read_receipts_in_room.invalidate((room_id,))
6060
6161 @cached(num_args=2)
6262 def get_receipts_for_room(self, room_id, receipt_type):
269269 }
270270 defer.returnValue(results)
271271
272 def get_max_receipt_stream_id(self):
273 return self._receipts_id_gen.get_current_token()
274
275 def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
276 user_id, event_id, data, stream_id):
277 txn.call_after(
278 self.get_receipts_for_room.invalidate, (room_id, receipt_type)
279 )
280 txn.call_after(
281 self._invalidate_get_users_with_receipts_in_room,
282 room_id, receipt_type, user_id,
283 )
284 txn.call_after(
285 self.get_receipts_for_user.invalidate, (user_id, receipt_type)
286 )
287 # FIXME: This shouldn't invalidate the whole cache
288 txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
289
290 txn.call_after(
291 self._receipts_stream_cache.entity_has_changed,
292 room_id, stream_id
293 )
294
295 txn.call_after(
296 self.get_last_receipt_event_id_for_user.invalidate,
297 (user_id, room_id, receipt_type)
298 )
299
300 res = self._simple_select_one_txn(
301 txn,
302 table="events",
303 retcols=["topological_ordering", "stream_ordering"],
304 keyvalues={"event_id": event_id},
305 allow_none=True
306 )
307
308 topological_ordering = int(res["topological_ordering"]) if res else None
309 stream_ordering = int(res["stream_ordering"]) if res else None
310
311 # We don't want to clobber receipts for more recent events, so we
312 # have to compare orderings of existing receipts
313 sql = (
314 "SELECT topological_ordering, stream_ordering, event_id FROM events"
315 " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
316 " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
317 )
318
319 txn.execute(sql, (room_id, receipt_type, user_id))
320
321 if topological_ordering:
322 for to, so, _ in txn:
323 if int(to) > topological_ordering:
324 return False
325 elif int(to) == topological_ordering and int(so) >= stream_ordering:
326 return False
327
328 self._simple_delete_txn(
329 txn,
330 table="receipts_linearized",
331 keyvalues={
332 "room_id": room_id,
333 "receipt_type": receipt_type,
334 "user_id": user_id,
335 }
336 )
337
338 self._simple_insert_txn(
339 txn,
340 table="receipts_linearized",
341 values={
342 "stream_id": stream_id,
343 "room_id": room_id,
344 "receipt_type": receipt_type,
345 "user_id": user_id,
346 "event_id": event_id,
347 "data": json.dumps(data),
348 }
349 )
350
351 if receipt_type == "m.read" and topological_ordering:
352 self._remove_old_push_actions_before_txn(
353 txn,
354 room_id=room_id,
355 user_id=user_id,
356 topological_ordering=topological_ordering,
357 stream_ordering=stream_ordering,
358 )
359
360 return True
361
362 @defer.inlineCallbacks
363 def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
364 """Insert a receipt, either from local client or remote server.
365
366 Automatically does conversion between linearized and graph
367 representations.
368 """
369 if not event_ids:
370 return
371
372 if len(event_ids) == 1:
373 linearized_event_id = event_ids[0]
374 else:
375 # we need to convert points in the graph into linearized form.
376 # TODO: Make this better.
377 def graph_to_linear(txn):
378 query = (
379 "SELECT event_id WHERE room_id = ? AND stream_ordering IN ("
380 " SELECT max(stream_ordering) WHERE event_id IN (%s)"
381 ")"
382 ) % (",".join(["?"] * len(event_ids)))
383
384 txn.execute(query, [room_id] + event_ids)
385 rows = txn.fetchall()
386 if rows:
387 return rows[0][0]
388 else:
389 raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
390
391 linearized_event_id = yield self.runInteraction(
392 "insert_receipt_conv", graph_to_linear
393 )
394
395 stream_id_manager = self._receipts_id_gen.get_next()
396 with stream_id_manager as stream_id:
397 have_persisted = yield self.runInteraction(
398 "insert_linearized_receipt",
399 self.insert_linearized_receipt_txn,
400 room_id, receipt_type, user_id, linearized_event_id,
401 data,
402 stream_id=stream_id,
403 )
404
405 if not have_persisted:
406 defer.returnValue(None)
407
408 yield self.insert_graph_receipt(
409 room_id, receipt_type, user_id, event_ids, data
410 )
411
412 max_persisted_id = self._receipts_id_gen.get_current_token()
413
414 defer.returnValue((stream_id, max_persisted_id))
415
416 def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
417 data):
418 return self.runInteraction(
419 "insert_graph_receipt",
420 self.insert_graph_receipt_txn,
421 room_id, receipt_type, user_id, event_ids, data
422 )
423
424 def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
425 user_id, event_ids, data):
426 txn.call_after(
427 self.get_receipts_for_room.invalidate, (room_id, receipt_type)
428 )
429 txn.call_after(
430 self._invalidate_get_users_with_receipts_in_room,
431 room_id, receipt_type, user_id,
432 )
433 txn.call_after(
434 self.get_receipts_for_user.invalidate, (user_id, receipt_type)
435 )
436 # FIXME: This shouldn't invalidate the whole cache
437 txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
438
439 self._simple_delete_txn(
440 txn,
441 table="receipts_graph",
442 keyvalues={
443 "room_id": room_id,
444 "receipt_type": receipt_type,
445 "user_id": user_id,
446 }
447 )
448 self._simple_insert_txn(
449 txn,
450 table="receipts_graph",
451 values={
452 "room_id": room_id,
453 "receipt_type": receipt_type,
454 "user_id": user_id,
455 "event_ids": json.dumps(event_ids),
456 "data": json.dumps(data),
457 }
458 )
459
460272 def get_all_updated_receipts(self, last_id, current_id, limit=None):
461273 if last_id == current_id:
462274 return defer.succeed([])
478290 return self.runInteraction(
479291 "get_all_updated_receipts", get_all_updated_receipts_txn
480292 )
293
294 def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
295 user_id):
296 if receipt_type != "m.read":
297 return
298
299 # Returns an ObservableDeferred
300 res = self.get_users_with_read_receipts_in_room.cache.get(
301 room_id, None, update_metrics=False,
302 )
303
304 if res:
305 if isinstance(res, defer.Deferred) and res.called:
306 res = res.result
307 if user_id in res:
308 # We'd only be adding to the set, so no point invalidating if the
309 # user is already there
310 return
311
312 self.get_users_with_read_receipts_in_room.invalidate((room_id,))
313
314
315 class ReceiptsStore(ReceiptsWorkerStore):
316 def __init__(self, db_conn, hs):
317 # We instantiate this first as the ReceiptsWorkerStore constructor
318 # needs to be able to call get_max_receipt_stream_id
319 self._receipts_id_gen = StreamIdGenerator(
320 db_conn, "receipts_linearized", "stream_id"
321 )
322
323 super(ReceiptsStore, self).__init__(db_conn, hs)
324
325 def get_max_receipt_stream_id(self):
326 return self._receipts_id_gen.get_current_token()
327
328 def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
329 user_id, event_id, data, stream_id):
330 txn.call_after(
331 self.get_receipts_for_room.invalidate, (room_id, receipt_type)
332 )
333 txn.call_after(
334 self._invalidate_get_users_with_receipts_in_room,
335 room_id, receipt_type, user_id,
336 )
337 txn.call_after(
338 self.get_receipts_for_user.invalidate, (user_id, receipt_type)
339 )
340 # FIXME: This shouldn't invalidate the whole cache
341 txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
342
343 txn.call_after(
344 self._receipts_stream_cache.entity_has_changed,
345 room_id, stream_id
346 )
347
348 txn.call_after(
349 self.get_last_receipt_event_id_for_user.invalidate,
350 (user_id, room_id, receipt_type)
351 )
352
353 res = self._simple_select_one_txn(
354 txn,
355 table="events",
356 retcols=["topological_ordering", "stream_ordering"],
357 keyvalues={"event_id": event_id},
358 allow_none=True
359 )
360
361 topological_ordering = int(res["topological_ordering"]) if res else None
362 stream_ordering = int(res["stream_ordering"]) if res else None
363
364 # We don't want to clobber receipts for more recent events, so we
365 # have to compare orderings of existing receipts
366 sql = (
367 "SELECT topological_ordering, stream_ordering, event_id FROM events"
368 " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
369 " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
370 )
371
372 txn.execute(sql, (room_id, receipt_type, user_id))
373
374 if topological_ordering:
375 for to, so, _ in txn:
376 if int(to) > topological_ordering:
377 return False
378 elif int(to) == topological_ordering and int(so) >= stream_ordering:
379 return False
380
381 self._simple_delete_txn(
382 txn,
383 table="receipts_linearized",
384 keyvalues={
385 "room_id": room_id,
386 "receipt_type": receipt_type,
387 "user_id": user_id,
388 }
389 )
390
391 self._simple_insert_txn(
392 txn,
393 table="receipts_linearized",
394 values={
395 "stream_id": stream_id,
396 "room_id": room_id,
397 "receipt_type": receipt_type,
398 "user_id": user_id,
399 "event_id": event_id,
400 "data": json.dumps(data),
401 }
402 )
403
404 if receipt_type == "m.read" and topological_ordering:
405 self._remove_old_push_actions_before_txn(
406 txn,
407 room_id=room_id,
408 user_id=user_id,
409 topological_ordering=topological_ordering,
410 stream_ordering=stream_ordering,
411 )
412
413 return True
414
415 @defer.inlineCallbacks
416 def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
417 """Insert a receipt, either from local client or remote server.
418
419 Automatically does conversion between linearized and graph
420 representations.
421 """
422 if not event_ids:
423 return
424
425 if len(event_ids) == 1:
426 linearized_event_id = event_ids[0]
427 else:
428 # we need to convert points in the graph into linearized form.
429 # TODO: Make this better.
430 def graph_to_linear(txn):
431 query = (
432 "SELECT event_id WHERE room_id = ? AND stream_ordering IN ("
433 " SELECT max(stream_ordering) WHERE event_id IN (%s)"
434 ")"
435 ) % (",".join(["?"] * len(event_ids)))
436
437 txn.execute(query, [room_id] + event_ids)
438 rows = txn.fetchall()
439 if rows:
440 return rows[0][0]
441 else:
442 raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
443
444 linearized_event_id = yield self.runInteraction(
445 "insert_receipt_conv", graph_to_linear
446 )
447
448 stream_id_manager = self._receipts_id_gen.get_next()
449 with stream_id_manager as stream_id:
450 have_persisted = yield self.runInteraction(
451 "insert_linearized_receipt",
452 self.insert_linearized_receipt_txn,
453 room_id, receipt_type, user_id, linearized_event_id,
454 data,
455 stream_id=stream_id,
456 )
457
458 if not have_persisted:
459 defer.returnValue(None)
460
461 yield self.insert_graph_receipt(
462 room_id, receipt_type, user_id, event_ids, data
463 )
464
465 max_persisted_id = self._receipts_id_gen.get_current_token()
466
467 defer.returnValue((stream_id, max_persisted_id))
468
469 def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
470 data):
471 return self.runInteraction(
472 "insert_graph_receipt",
473 self.insert_graph_receipt_txn,
474 room_id, receipt_type, user_id, event_ids, data
475 )
476
477 def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
478 user_id, event_ids, data):
479 txn.call_after(
480 self.get_receipts_for_room.invalidate, (room_id, receipt_type)
481 )
482 txn.call_after(
483 self._invalidate_get_users_with_receipts_in_room,
484 room_id, receipt_type, user_id,
485 )
486 txn.call_after(
487 self.get_receipts_for_user.invalidate, (user_id, receipt_type)
488 )
489 # FIXME: This shouldn't invalidate the whole cache
490 txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
491
492 self._simple_delete_txn(
493 txn,
494 table="receipts_graph",
495 keyvalues={
496 "room_id": room_id,
497 "receipt_type": receipt_type,
498 "user_id": user_id,
499 }
500 )
501 self._simple_insert_txn(
502 txn,
503 table="receipts_graph",
504 values={
505 "room_id": room_id,
506 "receipt_type": receipt_type,
507 "user_id": user_id,
508 "event_ids": json.dumps(event_ids),
509 "data": json.dumps(data),
510 }
511 )
1818
1919 from synapse.api.errors import StoreError, Codes
2020 from synapse.storage import background_updates
21 from synapse.storage._base import SQLBaseStore
2122 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
2223
2324
24 class RegistrationStore(background_updates.BackgroundUpdateStore):
25
26 def __init__(self, hs):
27 super(RegistrationStore, self).__init__(hs)
25 class RegistrationWorkerStore(SQLBaseStore):
26 @cached()
27 def get_user_by_id(self, user_id):
28 return self._simple_select_one(
29 table="users",
30 keyvalues={
31 "name": user_id,
32 },
33 retcols=["name", "password_hash", "is_guest"],
34 allow_none=True,
35 desc="get_user_by_id",
36 )
37
38 @cached()
39 def get_user_by_access_token(self, token):
40 """Get a user from the given access token.
41
42 Args:
43 token (str): The access token of a user.
44 Returns:
45 defer.Deferred: None, if the token did not match, otherwise dict
46 including the keys `name`, `is_guest`, `device_id`, `token_id`.
47 """
48 return self.runInteraction(
49 "get_user_by_access_token",
50 self._query_for_auth,
51 token
52 )
53
54 @defer.inlineCallbacks
55 def is_server_admin(self, user):
56 res = yield self._simple_select_one_onecol(
57 table="users",
58 keyvalues={"name": user.to_string()},
59 retcol="admin",
60 allow_none=True,
61 desc="is_server_admin",
62 )
63
64 defer.returnValue(res if res else False)
65
66 def _query_for_auth(self, txn, token):
67 sql = (
68 "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
69 " access_tokens.device_id"
70 " FROM users"
71 " INNER JOIN access_tokens on users.name = access_tokens.user_id"
72 " WHERE token = ?"
73 )
74
75 txn.execute(sql, (token,))
76 rows = self.cursor_to_dict(txn)
77 if rows:
78 return rows[0]
79
80 return None
81
82
83 class RegistrationStore(RegistrationWorkerStore,
84 background_updates.BackgroundUpdateStore):
85
86 def __init__(self, db_conn, hs):
87 super(RegistrationStore, self).__init__(db_conn, hs)
2888
2989 self.clock = hs.get_clock()
3090
3595 columns=["user_id", "device_id"],
3696 )
3797
38 self.register_background_index_update(
39 "refresh_tokens_device_index",
40 index_name="refresh_tokens_device_id",
41 table="refresh_tokens",
42 columns=["user_id", "device_id"],
43 )
98 # we no longer use refresh tokens, but it's possible that some people
99 # might have a background update queued to build this index. Just
100 # clear the background update.
101 self.register_noop_background_update("refresh_tokens_device_index")
44102
45103 @defer.inlineCallbacks
46104 def add_access_token_to_user(self, user_id, token, device_id=None):
176234 )
177235
178236 if create_profile_with_localpart:
237 # set a default displayname serverside to avoid ugly race
238 # between auto-joins and clients trying to set displaynames
179239 txn.execute(
180 "INSERT INTO profiles(user_id) VALUES (?)",
181 (create_profile_with_localpart,)
240 "INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
241 (create_profile_with_localpart, create_profile_with_localpart)
182242 )
183243
184244 self._invalidate_cache_and_stream(
185245 txn, self.get_user_by_id, (user_id,)
186246 )
187247 txn.call_after(self.is_guest.invalidate, (user_id,))
188
189 @cached()
190 def get_user_by_id(self, user_id):
191 return self._simple_select_one(
192 table="users",
193 keyvalues={
194 "name": user_id,
195 },
196 retcols=["name", "password_hash", "is_guest"],
197 allow_none=True,
198 desc="get_user_by_id",
199 )
200248
201249 def get_users_by_id_case_insensitive(self, user_id):
202250 """Gets users that match user_id case insensitively.
235283 "user_set_password_hash", user_set_password_hash_txn
236284 )
237285
238 @defer.inlineCallbacks
239286 def user_delete_access_tokens(self, user_id, except_token_id=None,
240 device_id=None,
241 delete_refresh_tokens=False):
242 """
243 Invalidate access/refresh tokens belonging to a user
287 device_id=None):
288 """
289 Invalidate access tokens belonging to a user
244290
245291 Args:
246292 user_id (str): ID of user the tokens belong to
249295 device_id (str|None): ID of device the tokens are associated with.
250296 If None, tokens associated with any device (or no device) will
251297 be deleted
252 delete_refresh_tokens (bool): True to delete refresh tokens as
253 well as access tokens.
254298 Returns:
255 defer.Deferred:
299 defer.Deferred[list[(str, int, str|None)]]: a list of
300 (token, token id, device id) for each of the deleted tokens
256301 """
257302 def f(txn):
258303 keyvalues = {
261306 if device_id is not None:
262307 keyvalues["device_id"] = device_id
263308
264 if delete_refresh_tokens:
265 self._simple_delete_txn(
266 txn,
267 table="refresh_tokens",
268 keyvalues=keyvalues,
269 )
270
271309 items = keyvalues.items()
272310 where_clause = " AND ".join(k + " = ?" for k, _ in items)
273311 values = [v for _, v in items]
276314 values.append(except_token_id)
277315
278316 txn.execute(
279 "SELECT token FROM access_tokens WHERE %s" % where_clause,
317 "SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause,
280318 values
281319 )
282 rows = self.cursor_to_dict(txn)
283
284 for row in rows:
320 tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
321
322 for token, _, _ in tokens_and_devices:
285323 self._invalidate_cache_and_stream(
286 txn, self.get_user_by_access_token, (row["token"],)
324 txn, self.get_user_by_access_token, (token,)
287325 )
288326
289327 txn.execute(
291329 values
292330 )
293331
294 yield self.runInteraction(
332 return tokens_and_devices
333
334 return self.runInteraction(
295335 "user_delete_access_tokens", f,
296336 )
297337
311351
312352 return self.runInteraction("delete_access_token", f)
313353
314 @cached()
315 def get_user_by_access_token(self, token):
316 """Get a user from the given access token.
317
318 Args:
319 token (str): The access token of a user.
320 Returns:
321 defer.Deferred: None, if the token did not match, otherwise dict
322 including the keys `name`, `is_guest`, `device_id`, `token_id`.
323 """
324 return self.runInteraction(
325 "get_user_by_access_token",
326 self._query_for_auth,
327 token
328 )
329
330 @defer.inlineCallbacks
331 def is_server_admin(self, user):
332 res = yield self._simple_select_one_onecol(
333 table="users",
334 keyvalues={"name": user.to_string()},
335 retcol="admin",
336 allow_none=True,
337 desc="is_server_admin",
338 )
339
340 defer.returnValue(res if res else False)
341
342354 @cachedInlineCallbacks()
343355 def is_guest(self, user_id):
344356 res = yield self._simple_select_one_onecol(
350362 )
351363
352364 defer.returnValue(res if res else False)
353
354 def _query_for_auth(self, txn, token):
355 sql = (
356 "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
357 " access_tokens.device_id"
358 " FROM users"
359 " INNER JOIN access_tokens on users.name = access_tokens.user_id"
360 " WHERE token = ?"
361 )
362
363 txn.execute(sql, (token,))
364 rows = self.cursor_to_dict(txn)
365 if rows:
366 return rows[0]
367
368 return None
369365
370366 @defer.inlineCallbacks
371367 def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
1515 from twisted.internet import defer
1616
1717 from synapse.api.errors import StoreError
18 from synapse.storage._base import SQLBaseStore
19 from synapse.storage.search import SearchStore
1820 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
19
20 from ._base import SQLBaseStore
21 from .engines import PostgresEngine, Sqlite3Engine
2221
2322 import collections
2423 import logging
25 import ujson as json
24 import simplejson as json
2625 import re
2726
2827 logger = logging.getLogger(__name__)
3938 )
4039
4140
42 class RoomStore(SQLBaseStore):
41 class RoomWorkerStore(SQLBaseStore):
42 def get_public_room_ids(self):
43 return self._simple_select_onecol(
44 table="rooms",
45 keyvalues={
46 "is_public": True,
47 },
48 retcol="room_id",
49 desc="get_public_room_ids",
50 )
51
52 @cached(num_args=2, max_entries=100)
53 def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
54 """Get pulbic rooms for a particular list, or across all lists.
55
56 Args:
57 stream_id (int)
58 network_tuple (ThirdPartyInstanceID): The list to use (None, None)
59 means the main list, None means all lists.
60 """
61 return self.runInteraction(
62 "get_public_room_ids_at_stream_id",
63 self.get_public_room_ids_at_stream_id_txn,
64 stream_id, network_tuple=network_tuple
65 )
66
67 def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
68 network_tuple):
69 return {
70 rm
71 for rm, vis in self.get_published_at_stream_id_txn(
72 txn, stream_id, network_tuple=network_tuple
73 ).items()
74 if vis
75 }
76
77 def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
78 if network_tuple:
79 # We want to get from a particular list. No aggregation required.
80
81 sql = ("""
82 SELECT room_id, visibility FROM public_room_list_stream
83 INNER JOIN (
84 SELECT room_id, max(stream_id) AS stream_id
85 FROM public_room_list_stream
86 WHERE stream_id <= ? %s
87 GROUP BY room_id
88 ) grouped USING (room_id, stream_id)
89 """)
90
91 if network_tuple.appservice_id is not None:
92 txn.execute(
93 sql % ("AND appservice_id = ? AND network_id = ?",),
94 (stream_id, network_tuple.appservice_id, network_tuple.network_id,)
95 )
96 else:
97 txn.execute(
98 sql % ("AND appservice_id IS NULL",),
99 (stream_id,)
100 )
101 return dict(txn)
102 else:
103 # We want to get from all lists, so we need to aggregate the results
104
105 logger.info("Executing full list")
106
107 sql = ("""
108 SELECT room_id, visibility
109 FROM public_room_list_stream
110 INNER JOIN (
111 SELECT
112 room_id, max(stream_id) AS stream_id, appservice_id,
113 network_id
114 FROM public_room_list_stream
115 WHERE stream_id <= ?
116 GROUP BY room_id, appservice_id, network_id
117 ) grouped USING (room_id, stream_id)
118 """)
119
120 txn.execute(
121 sql,
122 (stream_id,)
123 )
124
125 results = {}
126 # A room is visible if it's visible on any list.
127 for room_id, visibility in txn:
128 results[room_id] = bool(visibility) or results.get(room_id, False)
129
130 return results
131
132 def get_public_room_changes(self, prev_stream_id, new_stream_id,
133 network_tuple):
134 def get_public_room_changes_txn(txn):
135 then_rooms = self.get_public_room_ids_at_stream_id_txn(
136 txn, prev_stream_id, network_tuple
137 )
138
139 now_rooms_dict = self.get_published_at_stream_id_txn(
140 txn, new_stream_id, network_tuple
141 )
142
143 now_rooms_visible = set(
144 rm for rm, vis in now_rooms_dict.items() if vis
145 )
146 now_rooms_not_visible = set(
147 rm for rm, vis in now_rooms_dict.items() if not vis
148 )
149
150 newly_visible = now_rooms_visible - then_rooms
151 newly_unpublished = now_rooms_not_visible & then_rooms
152
153 return newly_visible, newly_unpublished
154
155 return self.runInteraction(
156 "get_public_room_changes", get_public_room_changes_txn
157 )
158
159 @cached(max_entries=10000)
160 def is_room_blocked(self, room_id):
161 return self._simple_select_one_onecol(
162 table="blocked_rooms",
163 keyvalues={
164 "room_id": room_id,
165 },
166 retcol="1",
167 allow_none=True,
168 desc="is_room_blocked",
169 )
170
171
172 class RoomStore(RoomWorkerStore, SearchStore):
43173
44174 @defer.inlineCallbacks
45175 def store_room(self, room_id, room_creator_user_id, is_public):
226356 )
227357 self.hs.get_notifier().on_new_replication_data()
228358
229 def get_public_room_ids(self):
230 return self._simple_select_onecol(
231 table="rooms",
232 keyvalues={
233 "is_public": True,
234 },
235 retcol="room_id",
236 desc="get_public_room_ids",
237 )
238
239359 def get_room_count(self):
240360 """Retrieve a list of all rooms
241361 """
262382 },
263383 )
264384
265 self._store_event_search_txn(
266 txn, event, "content.topic", event.content["topic"]
385 self.store_event_search_txn(
386 txn, event, "content.topic", event.content["topic"],
267387 )
268388
269389 def _store_room_name_txn(self, txn, event):
278398 }
279399 )
280400
281 self._store_event_search_txn(
282 txn, event, "content.name", event.content["name"]
401 self.store_event_search_txn(
402 txn, event, "content.name", event.content["name"],
283403 )
284404
285405 def _store_room_message_txn(self, txn, event):
286406 if hasattr(event, "content") and "body" in event.content:
287 self._store_event_search_txn(
288 txn, event, "content.body", event.content["body"]
407 self.store_event_search_txn(
408 txn, event, "content.body", event.content["body"],
289409 )
290410
291411 def _store_history_visibility_txn(self, txn, event):
306426 event.room_id,
307427 event.content[key]
308428 ))
309
310 def _store_event_search_txn(self, txn, event, key, value):
311 if isinstance(self.database_engine, PostgresEngine):
312 sql = (
313 "INSERT INTO event_search"
314 " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
315 " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
316 )
317 txn.execute(
318 sql,
319 (
320 event.event_id, event.room_id, key, value,
321 event.internal_metadata.stream_ordering,
322 event.origin_server_ts,
323 )
324 )
325 elif isinstance(self.database_engine, Sqlite3Engine):
326 sql = (
327 "INSERT INTO event_search (event_id, room_id, key, value)"
328 " VALUES (?,?,?,?)"
329 )
330 txn.execute(sql, (event.event_id, event.room_id, key, value,))
331 else:
332 # This should be unreachable.
333 raise Exception("Unrecognized database engine")
334429
335430 def add_event_report(self, room_id, event_id, user_id, reason, content,
336431 received_ts):
352447 def get_current_public_room_stream_id(self):
353448 return self._public_room_id_gen.get_current_token()
354449
355 @cached(num_args=2, max_entries=100)
356 def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
357 """Get pulbic rooms for a particular list, or across all lists.
358
359 Args:
360 stream_id (int)
361 network_tuple (ThirdPartyInstanceID): The list to use (None, None)
362 means the main list, None means all lists.
363 """
364 return self.runInteraction(
365 "get_public_room_ids_at_stream_id",
366 self.get_public_room_ids_at_stream_id_txn,
367 stream_id, network_tuple=network_tuple
368 )
369
370 def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
371 network_tuple):
372 return {
373 rm
374 for rm, vis in self.get_published_at_stream_id_txn(
375 txn, stream_id, network_tuple=network_tuple
376 ).items()
377 if vis
378 }
379
380 def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
381 if network_tuple:
382 # We want to get from a particular list. No aggregation required.
383
384 sql = ("""
385 SELECT room_id, visibility FROM public_room_list_stream
386 INNER JOIN (
387 SELECT room_id, max(stream_id) AS stream_id
388 FROM public_room_list_stream
389 WHERE stream_id <= ? %s
390 GROUP BY room_id
391 ) grouped USING (room_id, stream_id)
392 """)
393
394 if network_tuple.appservice_id is not None:
395 txn.execute(
396 sql % ("AND appservice_id = ? AND network_id = ?",),
397 (stream_id, network_tuple.appservice_id, network_tuple.network_id,)
398 )
399 else:
400 txn.execute(
401 sql % ("AND appservice_id IS NULL",),
402 (stream_id,)
403 )
404 return dict(txn)
405 else:
406 # We want to get from all lists, so we need to aggregate the results
407
408 logger.info("Executing full list")
409
410 sql = ("""
411 SELECT room_id, visibility
412 FROM public_room_list_stream
413 INNER JOIN (
414 SELECT
415 room_id, max(stream_id) AS stream_id, appservice_id,
416 network_id
417 FROM public_room_list_stream
418 WHERE stream_id <= ?
419 GROUP BY room_id, appservice_id, network_id
420 ) grouped USING (room_id, stream_id)
421 """)
422
423 txn.execute(
424 sql,
425 (stream_id,)
426 )
427
428 results = {}
429 # A room is visible if it's visible on any list.
430 for room_id, visibility in txn:
431 results[room_id] = bool(visibility) or results.get(room_id, False)
432
433 return results
434
435 def get_public_room_changes(self, prev_stream_id, new_stream_id,
436 network_tuple):
437 def get_public_room_changes_txn(txn):
438 then_rooms = self.get_public_room_ids_at_stream_id_txn(
439 txn, prev_stream_id, network_tuple
440 )
441
442 now_rooms_dict = self.get_published_at_stream_id_txn(
443 txn, new_stream_id, network_tuple
444 )
445
446 now_rooms_visible = set(
447 rm for rm, vis in now_rooms_dict.items() if vis
448 )
449 now_rooms_not_visible = set(
450 rm for rm, vis in now_rooms_dict.items() if not vis
451 )
452
453 newly_visible = now_rooms_visible - then_rooms
454 newly_unpublished = now_rooms_not_visible & then_rooms
455
456 return newly_visible, newly_unpublished
457
458 return self.runInteraction(
459 "get_public_room_changes", get_public_room_changes_txn
460 )
461
462450 def get_all_new_public_rooms(self, prev_id, current_id, limit):
463451 def get_all_new_public_rooms(txn):
464452 sql = ("""
508496 else:
509497 defer.returnValue(None)
510498
511 @cached(max_entries=10000)
512 def is_room_blocked(self, room_id):
513 return self._simple_select_one_onecol(
514 table="blocked_rooms",
515 keyvalues={
516 "room_id": room_id,
517 },
518 retcol="1",
519 allow_none=True,
520 desc="is_room_blocked",
521 )
522
523499 @defer.inlineCallbacks
524500 def block_room(self, room_id, user_id):
525501 yield self._simple_insert(
530506 },
531507 desc="block_room",
532508 )
533 self.is_room_blocked.invalidate((room_id,))
509 yield self.runInteraction(
510 "block_room_invalidation",
511 self._invalidate_cache_and_stream,
512 self.is_room_blocked, (room_id,),
513 )
514
515 def get_media_mxcs_in_room(self, room_id):
516 """Retrieves all the local and remote media MXC URIs in a given room
517
518 Args:
519 room_id (str)
520
521 Returns:
522 A tuple of two lists of strings: the MXC URIs of the local media in
523 the room, and the MXC URIs of the remote media in the room.
524 """
525 def _get_media_mxcs_in_room_txn(txn):
526 local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
527 local_media_mxcs = []
528 remote_media_mxcs = []
529
530 # Convert the IDs to MXC URIs
531 for media_id in local_mxcs:
532 local_media_mxcs.append("mxc://%s/%s" % (self.hostname, media_id))
533 for hostname, media_id in remote_mxcs:
534 remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
535
536 return local_media_mxcs, remote_media_mxcs
537 return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
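# Illustrative note (not part of this change): with a homeserver name of
# "example.org" and a media id of "abc123", the conversion above would
# produce the local URI "mxc://example.org/abc123"; the hostname and media
# id here are hypothetical example values.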
534538
535539 def quarantine_media_ids_in_room(self, room_id, quarantined_by):
536540 """For a room loops through all events with media and quarantines
537541 the associated media
538542 """
539 def _get_media_ids_in_room(txn):
540 mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
541
542 next_token = self.get_current_events_token() + 1
543
543 def _quarantine_media_in_room_txn(txn):
544 local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
544545 total_media_quarantined = 0
545546
546 while next_token:
547 sql = """
548 SELECT stream_ordering, content FROM events
549 WHERE room_id = ?
550 AND stream_ordering < ?
551 AND contains_url = ? AND outlier = ?
552 ORDER BY stream_ordering DESC
553 LIMIT ?
547 # Now update all the tables to set the quarantined_by flag
548
549 txn.executemany("""
550 UPDATE local_media_repository
551 SET quarantined_by = ?
552 WHERE media_id = ?
553 """, ((quarantined_by, media_id) for media_id in local_mxcs))
554
555 txn.executemany(
554556 """
555 txn.execute(sql, (room_id, next_token, True, False, 100))
556
557 next_token = None
558 local_media_mxcs = []
559 remote_media_mxcs = []
560 for stream_ordering, content_json in txn:
561 next_token = stream_ordering
562 content = json.loads(content_json)
563
564 content_url = content.get("url")
565 thumbnail_url = content.get("info", {}).get("thumbnail_url")
566
567 for url in (content_url, thumbnail_url):
568 if not url:
569 continue
570 matches = mxc_re.match(url)
571 if matches:
572 hostname = matches.group(1)
573 media_id = matches.group(2)
574 if hostname == self.hostname:
575 local_media_mxcs.append(media_id)
576 else:
577 remote_media_mxcs.append((hostname, media_id))
578
579 # Now update all the tables to set the quarantined_by flag
580
581 txn.executemany("""
582 UPDATE local_media_repository
557 UPDATE remote_media_cache
583558 SET quarantined_by = ?
584 WHERE media_id = ?
585 """, ((quarantined_by, media_id) for media_id in local_media_mxcs))
586
587 txn.executemany(
588 """
589 UPDATE remote_media_cache
590 SET quarantined_by = ?
591 WHERE media_origin AND media_id = ?
592 """,
593 (
594 (quarantined_by, origin, media_id)
595 for origin, media_id in remote_media_mxcs
596 )
597 )
598
599 total_media_quarantined += len(local_media_mxcs)
600 total_media_quarantined += len(remote_media_mxcs)
559 WHERE media_origin = ? AND media_id = ?
560 """,
561 (
562 (quarantined_by, origin, media_id)
563 for origin, media_id in remote_mxcs
564 )
565 )
566
567 total_media_quarantined += len(local_mxcs)
568 total_media_quarantined += len(remote_mxcs)
601569
602570 return total_media_quarantined
603571
604 return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room)
572 return self.runInteraction(
573 "quarantine_media_in_room",
574 _quarantine_media_in_room_txn,
575 )
576
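# A minimal usage sketch for quarantine_media_ids_in_room above
# (illustrative only; assumes a store instance `store` and an admin user id
# are in scope, inside a @defer.inlineCallbacks function):
#
#     total = yield store.quarantine_media_ids_in_room(
#         room_id, quarantined_by="@admin:example.org",
#     )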
577 def _get_media_mxcs_in_room_txn(self, txn, room_id):
578 """Retrieves all the local and remote media MXC URIs in a given room
579
580 Args:
581 txn (cursor)
582 room_id (str)
583
584 Returns:
585 A tuple of two lists: the media ids of the local media in the room,
586 and (hostname, media id) tuples for the remote media in the room.
587 """
588 mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
589
590 next_token = self.get_current_events_token() + 1
591 local_media_mxcs = []
592 remote_media_mxcs = []
593
594 while next_token:
595 sql = """
596 SELECT stream_ordering, content FROM events
597 WHERE room_id = ?
598 AND stream_ordering < ?
599 AND contains_url = ? AND outlier = ?
600 ORDER BY stream_ordering DESC
601 LIMIT ?
602 """
603 txn.execute(sql, (room_id, next_token, True, False, 100))
604
605 next_token = None
606 for stream_ordering, content_json in txn:
607 next_token = stream_ordering
608 content = json.loads(content_json)
609
610 content_url = content.get("url")
611 thumbnail_url = content.get("info", {}).get("thumbnail_url")
612
613 for url in (content_url, thumbnail_url):
614 if not url:
615 continue
616 matches = mxc_re.match(url)
617 if matches:
618 hostname = matches.group(1)
619 media_id = matches.group(2)
620 if hostname == self.hostname:
621 local_media_mxcs.append(media_id)
622 else:
623 remote_media_mxcs.append((hostname, media_id))
624
625 return local_media_mxcs, remote_media_mxcs
00 # -*- coding: utf-8 -*-
11 # Copyright 2014-2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1617
1718 from collections import namedtuple
1819
19 from ._base import SQLBaseStore
20 from synapse.storage.events import EventsWorkerStore
2021 from synapse.util.async import Linearizer
2122 from synapse.util.caches import intern_string
2223 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
2627 from synapse.types import get_domain_from_id
2728
2829 import logging
29 import ujson as json
30 import simplejson as json
3031
3132 logger = logging.getLogger(__name__)
3233
3435 RoomsForUser = namedtuple(
3536 "RoomsForUser",
3637 ("room_id", "sender", "membership", "event_id", "stream_ordering")
38 )
39
40 GetRoomsForUserWithStreamOrdering = namedtuple(
41 "_GetRoomsForUserWithStreamOrdering",
42 ("room_id", "stream_ordering",)
3743 )
3844
3945
4753 _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
4854
4955
50 class RoomMemberStore(SQLBaseStore):
51 def __init__(self, hs):
52 super(RoomMemberStore, self).__init__(hs)
56 class RoomMemberWorkerStore(EventsWorkerStore):
57 @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
58 def get_hosts_in_room(self, room_id, cache_context):
59 """Returns the set of all hosts currently in the room
60 """
61 user_ids = yield self.get_users_in_room(
62 room_id, on_invalidate=cache_context.invalidate,
63 )
64 hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
65 defer.returnValue(hosts)
66
67 @cached(max_entries=100000, iterable=True)
68 def get_users_in_room(self, room_id):
69 def f(txn):
70 sql = (
71 "SELECT m.user_id FROM room_memberships as m"
72 " INNER JOIN current_state_events as c"
73 " ON m.event_id = c.event_id "
74 " AND m.room_id = c.room_id "
75 " AND m.user_id = c.state_key"
76 " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
77 )
78
79 txn.execute(sql, (room_id, Membership.JOIN,))
80 return [to_ascii(r[0]) for r in txn]
81 return self.runInteraction("get_users_in_room", f)
82
83 @cached()
84 def get_invited_rooms_for_user(self, user_id):
85 """ Get all the rooms the user is invited to
86 Args:
87 user_id (str): The user ID.
88 Returns:
89 A deferred list of RoomsForUser.
90 """
91
92 return self.get_rooms_for_user_where_membership_is(
93 user_id, [Membership.INVITE]
94 )
95
96 @defer.inlineCallbacks
97 def get_invite_for_user_in_room(self, user_id, room_id):
98 """Gets the invite for the given user and room
99
100 Args:
101 user_id (str)
102 room_id (str)
103
104 Returns:
105 Deferred: Resolves to either a RoomsForUser or None if no invite was
106 found.
107 """
108 invites = yield self.get_invited_rooms_for_user(user_id)
109 for invite in invites:
110 if invite.room_id == room_id:
111 defer.returnValue(invite)
112 defer.returnValue(None)
113
114 def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
115 """ Get all the rooms for this user where the membership for this user
116 matches one in the membership list.
117
118 Args:
119 user_id (str): The user ID.
120 membership_list (list): A list of synapse.api.constants.Membership
121 values which the user must be in.
122 Returns:
123 A list of RoomsForUser namedtuples, each with room_id, sender,
124 membership, event_id and stream_ordering defined.
125 """
126 if not membership_list:
127 return defer.succeed(None)
128
129 return self.runInteraction(
130 "get_rooms_for_user_where_membership_is",
131 self._get_rooms_for_user_where_membership_is_txn,
132 user_id, membership_list
133 )
134
135 def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
136 membership_list):
137
138 do_invite = Membership.INVITE in membership_list
139 membership_list = [m for m in membership_list if m != Membership.INVITE]
140
141 results = []
142 if membership_list:
143 where_clause = "user_id = ? AND (%s) AND forgotten = 0" % (
144 " OR ".join(["membership = ?" for _ in membership_list]),
145 )
146
147 args = [user_id]
148 args.extend(membership_list)
149
150 sql = (
151 "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering"
152 " FROM current_state_events as c"
153 " INNER JOIN room_memberships as m"
154 " ON m.event_id = c.event_id"
155 " INNER JOIN events as e"
156 " ON e.event_id = c.event_id"
157 " AND m.room_id = c.room_id"
158 " AND m.user_id = c.state_key"
159 " WHERE c.type = 'm.room.member' AND %s"
160 ) % (where_clause,)
161
162 txn.execute(sql, args)
163 results = [
164 RoomsForUser(**r) for r in self.cursor_to_dict(txn)
165 ]
166
167 if do_invite:
168 sql = (
169 "SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
170 " FROM local_invites as i"
171 " INNER JOIN events as e USING (event_id)"
172 " WHERE invitee = ? AND locally_rejected is NULL"
173 " AND replaced_by is NULL"
174 )
175
176 txn.execute(sql, (user_id,))
177 results.extend(RoomsForUser(
178 room_id=r["room_id"],
179 sender=r["inviter"],
180 event_id=r["event_id"],
181 stream_ordering=r["stream_ordering"],
182 membership=Membership.INVITE,
183 ) for r in self.cursor_to_dict(txn))
184
185 return results
186
187 @cachedInlineCallbacks(max_entries=500000, iterable=True)
188 def get_rooms_for_user_with_stream_ordering(self, user_id):
189 """Returns a set of room_ids the user is currently joined to
190
191 Args:
192 user_id (str)
193
194 Returns:
195 Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
196 the rooms the user is in currently, along with the stream ordering
197 of the most recent join for that user and room.
198 """
199 rooms = yield self.get_rooms_for_user_where_membership_is(
200 user_id, membership_list=[Membership.JOIN],
201 )
202 defer.returnValue(frozenset(
203 GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
204 for r in rooms
205 ))
206
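# Illustrative access pattern (not part of this change): each entry in the
# returned frozenset is a GetRoomsForUserWithStreamOrdering namedtuple, so,
# assuming `rooms` holds the yielded result, something like
# `sorted(r.stream_ordering for r in rooms)` lists the stream orderings of
# the user's most recent joins.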
207 @defer.inlineCallbacks
208 def get_rooms_for_user(self, user_id, on_invalidate=None):
209 """Returns a set of room_ids the user is currently joined to
210 """
211 rooms = yield self.get_rooms_for_user_with_stream_ordering(
212 user_id, on_invalidate=on_invalidate,
213 )
214 defer.returnValue(frozenset(r.room_id for r in rooms))
215
216 @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
217 def get_users_who_share_room_with_user(self, user_id, cache_context):
218 """Returns the set of users who share a room with `user_id`
219 """
220 room_ids = yield self.get_rooms_for_user(
221 user_id, on_invalidate=cache_context.invalidate,
222 )
223
224 user_who_share_room = set()
225 for room_id in room_ids:
226 user_ids = yield self.get_users_in_room(
227 room_id, on_invalidate=cache_context.invalidate,
228 )
229 user_who_share_room.update(user_ids)
230
231 defer.returnValue(user_who_share_room)
232
233 def get_joined_users_from_context(self, event, context):
234 state_group = context.state_group
235 if not state_group:
236 # If state_group is None it means it has yet to be assigned a
237 # state group, i.e. we need to make sure that calls with a state_group
238 # of None don't hit previous cached calls with a None state_group.
239 # To do this we set the state_group to a new object as object() != object()
240 state_group = object()
241
242 return self._get_joined_users_from_context(
243 event.room_id, state_group, context.current_state_ids,
244 event=event,
245 context=context,
246 )
247
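# A minimal sketch of why object() works as a cache sentinel in the methods
# above and below (illustrative only): two freshly created objects never
# compare equal, so each call without a real state group gets its own
# distinct cache key and cannot collide with another None-state-group call.
#
#     sentinel_a = object()
#     sentinel_b = object()
#     assert sentinel_a != sentinel_b   # distinct keys, so no false cache hits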
248 def get_joined_users_from_state(self, room_id, state_entry):
249 state_group = state_entry.state_group
250 if not state_group:
251 # If state_group is None it means it has yet to be assigned a
252 # state group, i.e. we need to make sure that calls with a state_group
253 # of None don't hit previous cached calls with a None state_group.
254 # To do this we set the state_group to a new object as object() != object()
255 state_group = object()
256
257 return self._get_joined_users_from_context(
258 room_id, state_group, state_entry.state, context=state_entry,
259 )
260
261 @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
262 max_entries=100000)
263 def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
264 cache_context, event=None, context=None):
265 # We don't use `state_group`, it's there so that we can cache based
266 # on it. However, it's important that it's never None, since two current_states
267 # with a state_group of None are likely to be different.
268 # See bulk_get_push_rules_for_room for how we work around this.
269 assert state_group is not None
270
271 users_in_room = {}
272 member_event_ids = [
273 e_id
274 for key, e_id in current_state_ids.iteritems()
275 if key[0] == EventTypes.Member
276 ]
277
278 if context is not None:
279 # If we have a context with a delta from a previous state group,
280 # check if we also have the result from the previous group in cache.
281 # If we do then we can reuse that result and simply update it with
282 # any membership changes in `delta_ids`
283 if context.prev_group and context.delta_ids:
284 prev_res = self._get_joined_users_from_context.cache.get(
285 (room_id, context.prev_group), None
286 )
287 if prev_res and isinstance(prev_res, dict):
288 users_in_room = dict(prev_res)
289 member_event_ids = [
290 e_id
291 for key, e_id in context.delta_ids.iteritems()
292 if key[0] == EventTypes.Member
293 ]
294 for etype, state_key in context.delta_ids:
295 users_in_room.pop(state_key, None)
296
297 # We check if we have any of the member event ids in the event cache
298 # before we ask the DB
299
300 # We don't update the event cache hit ratio as it completely throws off
301 # the hit ratio counts. After all, we don't populate the cache if we
302 # miss it here
303 event_map = self._get_events_from_cache(
304 member_event_ids,
305 allow_rejected=False,
306 update_metrics=False,
307 )
308
309 missing_member_event_ids = []
310 for event_id in member_event_ids:
311 ev_entry = event_map.get(event_id)
312 if ev_entry:
313 if ev_entry.event.membership == Membership.JOIN:
314 users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo(
315 display_name=to_ascii(
316 ev_entry.event.content.get("displayname", None)
317 ),
318 avatar_url=to_ascii(
319 ev_entry.event.content.get("avatar_url", None)
320 ),
321 )
322 else:
323 missing_member_event_ids.append(event_id)
324
325 if missing_member_event_ids:
326 rows = yield self._simple_select_many_batch(
327 table="room_memberships",
328 column="event_id",
329 iterable=missing_member_event_ids,
330 retcols=('user_id', 'display_name', 'avatar_url',),
331 keyvalues={
332 "membership": Membership.JOIN,
333 },
334 batch_size=500,
335 desc="_get_joined_users_from_context",
336 )
337
338 users_in_room.update({
339 to_ascii(row["user_id"]): ProfileInfo(
340 avatar_url=to_ascii(row["avatar_url"]),
341 display_name=to_ascii(row["display_name"]),
342 )
343 for row in rows
344 })
345
346 if event is not None and event.type == EventTypes.Member:
347 if event.membership == Membership.JOIN:
348 if event.event_id in member_event_ids:
349 users_in_room[to_ascii(event.state_key)] = ProfileInfo(
350 display_name=to_ascii(event.content.get("displayname", None)),
351 avatar_url=to_ascii(event.content.get("avatar_url", None)),
352 )
353
354 defer.returnValue(users_in_room)
355
356 @cachedInlineCallbacks(max_entries=10000)
357 def is_host_joined(self, room_id, host):
358 if '%' in host or '_' in host:
359 raise Exception("Invalid host name")
360
361 sql = """
362 SELECT state_key FROM current_state_events AS c
363 INNER JOIN room_memberships USING (event_id)
364 WHERE membership = 'join'
365 AND type = 'm.room.member'
366 AND c.room_id = ?
367 AND state_key LIKE ?
368 LIMIT 1
369 """
370
371 # We do need to be careful to ensure that host doesn't have any wild cards
372 # in it, but we checked above for known ones and we'll check below that
373 # the returned user actually has the correct domain.
374 like_clause = "%:" + host
375
376 rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause)
377
378 if not rows:
379 defer.returnValue(False)
380
381 user_id = rows[0][0]
382 if get_domain_from_id(user_id) != host:
383 # This can only happen if the host name has something funky in it
384 raise Exception("Invalid host name")
385
386 defer.returnValue(True)
387
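# Illustrative note for the LIKE clause above (not part of this change): for
# host "example.org" the clause becomes "%:example.org", which matches user
# ids such as "@alice:example.org"; '%' and '_' in the host are rejected
# earlier because they are SQL LIKE wildcards.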
388 @cachedInlineCallbacks()
389 def was_host_joined(self, room_id, host):
390 """Check whether the server is or ever was in the room.
391
392 Args:
393 room_id (str)
394 host (str)
395
396 Returns:
397 Deferred: Resolves to True if the host is/was in the room, otherwise
398 False.
399 """
400 if '%' in host or '_' in host:
401 raise Exception("Invalid host name")
402
403 sql = """
404 SELECT user_id FROM room_memberships
405 WHERE room_id = ?
406 AND user_id LIKE ?
407 AND membership = 'join'
408 LIMIT 1
409 """
410
411 # We do need to be careful to ensure that host doesn't have any wild cards
412 # in it, but we checked above for known ones and we'll check below that
413 # the returned user actually has the correct domain.
414 like_clause = "%:" + host
415
416 rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause)
417
418 if not rows:
419 defer.returnValue(False)
420
421 user_id = rows[0][0]
422 if get_domain_from_id(user_id) != host:
423 # This can only happen if the host name has something funky in it
424 raise Exception("Invalid host name")
425
426 defer.returnValue(True)
427
428 def get_joined_hosts(self, room_id, state_entry):
429 state_group = state_entry.state_group
430 if not state_group:
431 # If state_group is None it means it has yet to be assigned a
432 # state group, i.e. we need to make sure that calls with a state_group
433 # of None don't hit previous cached calls with a None state_group.
434 # To do this we set the state_group to a new object as object() != object()
435 state_group = object()
436
437 return self._get_joined_hosts(
438 room_id, state_group, state_entry.state, state_entry=state_entry,
439 )
440
441 @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
442 # @defer.inlineCallbacks
443 def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
444 # We don't use `state_group`, it's there so that we can cache based
445 # on it. However, it's important that it's never None, since two current_states
446 # with a state_group of None are likely to be different.
447 # See bulk_get_push_rules_for_room for how we work around this.
448 assert state_group is not None
449
450 cache = self._get_joined_hosts_cache(room_id)
451 joined_hosts = yield cache.get_destinations(state_entry)
452
453 defer.returnValue(joined_hosts)
454
455 @cached(max_entries=10000, iterable=True)
456 def _get_joined_hosts_cache(self, room_id):
457 return _JoinedHostsCache(self, room_id)
458
459 @cached()
460 def who_forgot_in_room(self, room_id):
461 return self._simple_select_list(
462 table="room_memberships",
463 retcols=("user_id", "event_id"),
464 keyvalues={
465 "room_id": room_id,
466 "forgotten": 1,
467 },
468 desc="who_forgot"
469 )
470
471
472 class RoomMemberStore(RoomMemberWorkerStore):
473 def __init__(self, db_conn, hs):
474 super(RoomMemberStore, self).__init__(db_conn, hs)
53475 self.register_background_update_handler(
54476 _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
55477 )
138560 with self._stream_id_gen.get_next() as stream_ordering:
139561 yield self.runInteraction("locally_reject_invite", f, stream_ordering)
140562
141 @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
142 def get_hosts_in_room(self, room_id, cache_context):
143 """Returns the set of all hosts currently in the room
144 """
145 user_ids = yield self.get_users_in_room(
146 room_id, on_invalidate=cache_context.invalidate,
147 )
148 hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
149 defer.returnValue(hosts)
150
151 @cached(max_entries=100000, iterable=True)
152 def get_users_in_room(self, room_id):
153 def f(txn):
154 sql = (
155 "SELECT m.user_id FROM room_memberships as m"
156 " INNER JOIN current_state_events as c"
157 " ON m.event_id = c.event_id "
158 " AND m.room_id = c.room_id "
159 " AND m.user_id = c.state_key"
160 " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
161 )
162
163 txn.execute(sql, (room_id, Membership.JOIN,))
164 return [to_ascii(r[0]) for r in txn]
165 return self.runInteraction("get_users_in_room", f)
166
167 @cached()
168 def get_invited_rooms_for_user(self, user_id):
169 """ Get all the rooms the user is invited to
170 Args:
171 user_id (str): The user ID.
172 Returns:
173 A deferred list of RoomsForUser.
174 """
175
176 return self.get_rooms_for_user_where_membership_is(
177 user_id, [Membership.INVITE]
178 )
179
180 @defer.inlineCallbacks
181 def get_invite_for_user_in_room(self, user_id, room_id):
182 """Gets the invite for the given user and room
183
184 Args:
185 user_id (str)
186 room_id (str)
187
188 Returns:
189 Deferred: Resolves to either a RoomsForUser or None if no invite was
190 found.
191 """
192 invites = yield self.get_invited_rooms_for_user(user_id)
193 for invite in invites:
194 if invite.room_id == room_id:
195 defer.returnValue(invite)
196 defer.returnValue(None)
197
198 def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
199 """ Get all the rooms for this user where the membership for this user
200 matches one in the membership list.
201
202 Args:
203 user_id (str): The user ID.
204 membership_list (list): A list of synapse.api.constants.Membership
205 values which the user must be in.
206 Returns:
207 A list of dictionary objects, with room_id, membership and sender
208 defined.
209 """
210 if not membership_list:
211 return defer.succeed(None)
212
213 return self.runInteraction(
214 "get_rooms_for_user_where_membership_is",
215 self._get_rooms_for_user_where_membership_is_txn,
216 user_id, membership_list
217 )
218
219 def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
220 membership_list):
221
222 do_invite = Membership.INVITE in membership_list
223 membership_list = [m for m in membership_list if m != Membership.INVITE]
224
225 results = []
226 if membership_list:
227 where_clause = "user_id = ? AND (%s) AND forgotten = 0" % (
228 " OR ".join(["membership = ?" for _ in membership_list]),
229 )
230
231 args = [user_id]
232 args.extend(membership_list)
233
234 sql = (
235 "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering"
236 " FROM current_state_events as c"
237 " INNER JOIN room_memberships as m"
238 " ON m.event_id = c.event_id"
239 " INNER JOIN events as e"
240 " ON e.event_id = c.event_id"
241 " AND m.room_id = c.room_id"
242 " AND m.user_id = c.state_key"
243 " WHERE c.type = 'm.room.member' AND %s"
244 ) % (where_clause,)
245
246 txn.execute(sql, args)
247 results = [
248 RoomsForUser(**r) for r in self.cursor_to_dict(txn)
249 ]
250
251 if do_invite:
252 sql = (
253 "SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
254 " FROM local_invites as i"
255 " INNER JOIN events as e USING (event_id)"
256 " WHERE invitee = ? AND locally_rejected is NULL"
257 " AND replaced_by is NULL"
258 )
259
260 txn.execute(sql, (user_id,))
261 results.extend(RoomsForUser(
262 room_id=r["room_id"],
263 sender=r["inviter"],
264 event_id=r["event_id"],
265 stream_ordering=r["stream_ordering"],
266 membership=Membership.INVITE,
267 ) for r in self.cursor_to_dict(txn))
268
269 return results
270
271 @cachedInlineCallbacks(max_entries=500000, iterable=True)
272 def get_rooms_for_user(self, user_id):
273 """Returns a set of room_ids the user is currently joined to
274 """
275 rooms = yield self.get_rooms_for_user_where_membership_is(
276 user_id, membership_list=[Membership.JOIN],
277 )
278 defer.returnValue(frozenset(r.room_id for r in rooms))
279
280 @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
281 def get_users_who_share_room_with_user(self, user_id, cache_context):
282 """Returns the set of users who share a room with `user_id`
283 """
284 room_ids = yield self.get_rooms_for_user(
285 user_id, on_invalidate=cache_context.invalidate,
286 )
287
288 user_who_share_room = set()
289 for room_id in room_ids:
290 user_ids = yield self.get_users_in_room(
291 room_id, on_invalidate=cache_context.invalidate,
292 )
293 user_who_share_room.update(user_ids)
294
295 defer.returnValue(user_who_share_room)
296
297563 def forget(self, user_id, room_id):
298564 """Indicate that user_id wishes to discard history for room_id."""
299565 def f(txn):
365631 forgot = yield self.runInteraction("did_forget_membership_at", f)
366632 defer.returnValue(forgot == 1)
367633
368 @cached()
369 def who_forgot_in_room(self, room_id):
370 return self._simple_select_list(
371 table="room_memberships",
372 retcols=("user_id", "event_id"),
373 keyvalues={
374 "room_id": room_id,
375 "forgotten": 1,
376 },
377 desc="who_forgot"
378 )
379
380 def get_joined_users_from_context(self, event, context):
381 state_group = context.state_group
382 if not state_group:
383 # If state_group is None it means it has yet to be assigned a
384 # state group, i.e. we need to make sure that calls with a state_group
385 # of None don't hit previous cached calls with a None state_group.
386 # To do this we set the state_group to a new object as object() != object()
387 state_group = object()
388
389 return self._get_joined_users_from_context(
390 event.room_id, state_group, context.current_state_ids,
391 event=event,
392 context=context,
393 )
394
395 def get_joined_users_from_state(self, room_id, state_entry):
396 state_group = state_entry.state_group
397 if not state_group:
398 # If state_group is None it means it has yet to be assigned a
399 # state group, i.e. we need to make sure that calls with a state_group
400 # of None don't hit previous cached calls with a None state_group.
401 # To do this we set the state_group to a new object as object() != object()
402 state_group = object()
403
404 return self._get_joined_users_from_context(
405 room_id, state_group, state_entry.state, context=state_entry,
406 )
407
408 @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
409 max_entries=100000)
410 def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
411 cache_context, event=None, context=None):
412 # We don't use `state_group`, it's there so that we can cache based
413 # on it. However, it's important that it's never None, since two current_states
414 # with a state_group of None are likely to be different.
415 # See bulk_get_push_rules_for_room for how we work around this.
416 assert state_group is not None
417
418 users_in_room = {}
419 member_event_ids = [
420 e_id
421 for key, e_id in current_state_ids.iteritems()
422 if key[0] == EventTypes.Member
423 ]
424
425 if context is not None:
426 # If we have a context with a delta from a previous state group,
427 # check if we also have the result from the previous group in cache.
428 # If we do then we can reuse that result and simply update it with
429 # any membership changes in `delta_ids`
430 if context.prev_group and context.delta_ids:
431 prev_res = self._get_joined_users_from_context.cache.get(
432 (room_id, context.prev_group), None
433 )
434 if prev_res and isinstance(prev_res, dict):
435 users_in_room = dict(prev_res)
436 member_event_ids = [
437 e_id
438 for key, e_id in context.delta_ids.iteritems()
439 if key[0] == EventTypes.Member
440 ]
441 for etype, state_key in context.delta_ids:
442 users_in_room.pop(state_key, None)
443
444 # We check if we have any of the member event ids in the event cache
445 # before we ask the DB
446
447 # We don't update the event cache hit ratio as it completely throws off
448 # the hit ratio counts. After all, we don't populate the cache if we
449 # miss it here
450 event_map = self._get_events_from_cache(
451 member_event_ids,
452 allow_rejected=False,
453 update_metrics=False,
454 )
455
456 missing_member_event_ids = []
457 for event_id in member_event_ids:
458 ev_entry = event_map.get(event_id)
459 if ev_entry:
460 if ev_entry.event.membership == Membership.JOIN:
461 users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo(
462 display_name=to_ascii(
463 ev_entry.event.content.get("displayname", None)
464 ),
465 avatar_url=to_ascii(
466 ev_entry.event.content.get("avatar_url", None)
467 ),
468 )
469 else:
470 missing_member_event_ids.append(event_id)
471
472 if missing_member_event_ids:
473 rows = yield self._simple_select_many_batch(
474 table="room_memberships",
475 column="event_id",
476 iterable=missing_member_event_ids,
477 retcols=('user_id', 'display_name', 'avatar_url',),
478 keyvalues={
479 "membership": Membership.JOIN,
480 },
481 batch_size=500,
482 desc="_get_joined_users_from_context",
483 )
484
485 users_in_room.update({
486 to_ascii(row["user_id"]): ProfileInfo(
487 avatar_url=to_ascii(row["avatar_url"]),
488 display_name=to_ascii(row["display_name"]),
489 )
490 for row in rows
491 })
492
493 if event is not None and event.type == EventTypes.Member:
494 if event.membership == Membership.JOIN:
495 if event.event_id in member_event_ids:
496 users_in_room[to_ascii(event.state_key)] = ProfileInfo(
497 display_name=to_ascii(event.content.get("displayname", None)),
498 avatar_url=to_ascii(event.content.get("avatar_url", None)),
499 )
500
501 defer.returnValue(users_in_room)
502
503 @cachedInlineCallbacks(max_entries=10000)
504 def is_host_joined(self, room_id, host):
505 if '%' in host or '_' in host:
506 raise Exception("Invalid host name")
507
508 sql = """
509 SELECT state_key FROM current_state_events AS c
510 INNER JOIN room_memberships USING (event_id)
511 WHERE membership = 'join'
512 AND type = 'm.room.member'
513 AND c.room_id = ?
514 AND state_key LIKE ?
515 LIMIT 1
516 """
517
518 # We do need to be careful to ensure that host doesn't have any wild cards
519 # in it, but we checked above for known ones and we'll check below that
520 # the returned user actually has the correct domain.
521 like_clause = "%:" + host
522
523 rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause)
524
525 if not rows:
526 defer.returnValue(False)
527
528 user_id = rows[0][0]
529 if get_domain_from_id(user_id) != host:
530 # This can only happen if the host name has something funky in it
531 raise Exception("Invalid host name")
532
533 defer.returnValue(True)
534
535 @cachedInlineCallbacks()
536 def was_host_joined(self, room_id, host):
537 """Check whether the server is or ever was in the room.
538
539 Args:
540 room_id (str)
541 host (str)
542
543 Returns:
544 Deferred: Resolves to True if the host is/was in the room, otherwise
545 False.
546 """
547 if '%' in host or '_' in host:
548 raise Exception("Invalid host name")
549
550 sql = """
551 SELECT user_id FROM room_memberships
552 WHERE room_id = ?
553 AND user_id LIKE ?
554 AND membership = 'join'
555 LIMIT 1
556 """
557
558 # We do need to be careful to ensure that host doesn't have any wild cards
559 # in it, but we checked above for known ones and we'll check below that
560 # the returned user actually has the correct domain.
561 like_clause = "%:" + host
562
563 rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause)
564
565 if not rows:
566 defer.returnValue(False)
567
568 user_id = rows[0][0]
569 if get_domain_from_id(user_id) != host:
570 # This can only happen if the host name has something funky in it
571 raise Exception("Invalid host name")
572
573 defer.returnValue(True)
574
575 def get_joined_hosts(self, room_id, state_entry):
576 state_group = state_entry.state_group
577 if not state_group:
578 # If state_group is None it means it has yet to be assigned a
579 # state group, i.e. we need to make sure that calls with a state_group
580 # of None don't hit previous cached calls with a None state_group.
581 # To do this we set the state_group to a new object as object() != object()
582 state_group = object()
583
584 return self._get_joined_hosts(
585 room_id, state_group, state_entry.state, state_entry=state_entry,
586 )
587
588 @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
589 # @defer.inlineCallbacks
590 def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
591 # We don't use `state_group`, it's there so that we can cache based
592 # on it. However, it's important that it's never None, since two current_states
593 # with a state_group of None are likely to be different.
594 # See bulk_get_push_rules_for_room for how we work around this.
595 assert state_group is not None
596
597 cache = self._get_joined_hosts_cache(room_id)
598 joined_hosts = yield cache.get_destinations(state_entry)
599
600 defer.returnValue(joined_hosts)
601
602634 @defer.inlineCallbacks
603635 def _background_add_membership_profile(self, progress, batch_size):
604636 target_min_stream_id = progress.get(
635667 room_id = row["room_id"]
636668 try:
637669 content = json.loads(row["content"])
638 except:
670 except Exception:
639671 continue
640672
641673 display_name = content.get("displayname", None)
674706
675707 defer.returnValue(result)
676708
677 @cached(max_entries=10000, iterable=True)
678 def _get_joined_hosts_cache(self, room_id):
679 return _JoinedHostsCache(self, room_id)
680
681709
682710 class _JoinedHostsCache(object):
683711 """Cache for joined hosts in a room that is optimised to handle updates
+0
-21
synapse/storage/schema/delta/23/refresh_tokens.sql
0 /* Copyright 2015, 2016 OpenMarket Ltd
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 CREATE TABLE IF NOT EXISTS refresh_tokens(
16 id INTEGER PRIMARY KEY,
17 token TEXT NOT NULL,
18 user_id TEXT NOT NULL,
19 UNIQUE (token)
20 );
1616 from synapse.storage.prepare_database import get_statements
1717 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
1818
19 import ujson
19 import simplejson
2020
2121 logger = logging.getLogger(__name__)
2222
6565 "max_stream_id_exclusive": max_stream_id + 1,
6666 "rows_inserted": 0,
6767 }
68 progress_json = ujson.dumps(progress)
68 progress_json = simplejson.dumps(progress)
6969
7070 sql = (
7171 "INSERT into background_updates (update_name, progress_json)"
1515
1616 from synapse.storage.prepare_database import get_statements
1717
18 import ujson
18 import simplejson
1919
2020 logger = logging.getLogger(__name__)
2121
4444 "max_stream_id_exclusive": max_stream_id + 1,
4545 "rows_inserted": 0,
4646 }
47 progress_json = ujson.dumps(progress)
47 progress_json = simplejson.dumps(progress)
4848
4949 sql = (
5050 "INSERT into background_updates (update_name, progress_json)"
2121 # NULL indicates user was not registered by an appservice.
2222 try:
2323 cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT")
24 except:
24 except Exception:
2525 # Maybe we already added the column? Hope so...
2626 pass
2727
1515 from synapse.storage.prepare_database import get_statements
1616
1717 import logging
18 import ujson
18 import simplejson
1919
2020 logger = logging.getLogger(__name__)
2121
4848 "rows_inserted": 0,
4949 "have_added_indexes": False,
5050 }
51 progress_json = ujson.dumps(progress)
51 progress_json = simplejson.dumps(progress)
5252
5353 sql = (
5454 "INSERT into background_updates (update_name, progress_json)"
1414 from synapse.storage.prepare_database import get_statements
1515
1616 import logging
17 import ujson
17 import simplejson
1818
1919 logger = logging.getLogger(__name__)
2020
4343 "max_stream_id_exclusive": max_stream_id + 1,
4444 "rows_inserted": 0,
4545 }
46 progress_json = ujson.dumps(progress)
46 progress_json = simplejson.dumps(progress)
4747
4848 sql = (
4949 "INSERT into background_updates (update_name, progress_json)"
+0
-16
synapse/storage/schema/delta/33/refreshtoken_device.sql
0 /* Copyright 2016 OpenMarket Ltd
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 ALTER TABLE refresh_tokens ADD COLUMN device_id TEXT;
+0
-17
synapse/storage/schema/delta/33/refreshtoken_device_index.sql
0 /* Copyright 2016 OpenMarket Ltd
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 INSERT INTO background_updates (update_name, progress_json) VALUES
16 ('refresh_tokens_device_index', '{}');
1212 * limitations under the License.
1313 */
1414
15 INSERT into background_updates (update_name, progress_json)
16 VALUES ('event_search_postgres_gist', '{}');
15 -- We no longer do this given we back it out again in schema 47
16
17 -- INSERT into background_updates (update_name, progress_json)
18 -- VALUES ('event_search_postgres_gist', '{}');
2828 CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id);
2929
3030
31 -- Make sure that we popualte the table initially
31 -- Make sure that we populate the table initially
3232 UPDATE user_directory_stream_pos SET stream_id = NULL;
1212 * limitations under the License.
1313 */
1414
15 CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
15 -- this didn't work on SQLite 3.7 (because of lack of partial indexes), so was
16 -- removed and replaced with 46/local_media_repository_url_idx.sql.
17 --
18 -- CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
1619
1720 -- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support
1821 -- indices on expressions until 3.9.
0 /* Copyright 2017 New Vector Ltd
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 /* we no longer use (or create) the refresh_tokens table */
16 DROP TABLE IF EXISTS refresh_tokens;
0 /* Copyright 2017 New Vector Ltd
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 -- drop the unique constraint on deleted_pushers so that we can just insert
16 -- into it rather than upserting.
17
18 CREATE TABLE deleted_pushers2 (
19 stream_id BIGINT NOT NULL,
20 app_id TEXT NOT NULL,
21 pushkey TEXT NOT NULL,
22 user_id TEXT NOT NULL
23 );
24
25 INSERT INTO deleted_pushers2 (stream_id, app_id, pushkey, user_id)
26 SELECT stream_id, app_id, pushkey, user_id from deleted_pushers;
27
28 DROP TABLE deleted_pushers;
29 ALTER TABLE deleted_pushers2 RENAME TO deleted_pushers;
30
31 -- create the index after doing the inserts because that's more efficient.
32 -- it also means we can give it the same name as the old one without renaming.
33 CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id);
34
0 /* Copyright 2017 New Vector Ltd
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 CREATE TABLE groups_new (
16 group_id TEXT NOT NULL,
17 name TEXT, -- the display name of the room
18 avatar_url TEXT,
19 short_description TEXT,
20 long_description TEXT,
21 is_public BOOL NOT NULL -- whether non-members can access group APIs
22 );
23
24 -- NB: awful hack to get the default to be true on postgres and 1 on sqlite
25 INSERT INTO groups_new
26 SELECT group_id, name, avatar_url, short_description, long_description, (1=1) FROM groups;
27
28 DROP TABLE groups;
29 ALTER TABLE groups_new RENAME TO groups;
30
31 CREATE UNIQUE INDEX groups_idx ON groups(group_id);
0 /* Copyright 2017 New Vector Ltd
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 -- register a background update which will recreate the
16 -- local_media_repository_url_idx index.
17 --
18 -- We do this as a bg update not because it is a particularly onerous
19 -- operation, but because we'd like it to be a partial index if possible, and
20 -- the background_index_update code will understand whether we are on
21 -- postgres or sqlite and behave accordingly.
22 INSERT INTO background_updates (update_name, progress_json) VALUES
23 ('local_media_repository_url_idx', '{}');
0 /* Copyright 2017 New Vector Ltd
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 -- change the user_directory table to also cover global local user profiles
16 -- rather than just profiles within specific rooms.
17
18 CREATE TABLE user_directory2 (
19 user_id TEXT NOT NULL,
20 room_id TEXT,
21 display_name TEXT,
22 avatar_url TEXT
23 );
24
25 INSERT INTO user_directory2(user_id, room_id, display_name, avatar_url)
26 SELECT user_id, room_id, display_name, avatar_url from user_directory;
27
28 DROP TABLE user_directory;
29 ALTER TABLE user_directory2 RENAME TO user_directory;
30
31 -- create indexes after doing the inserts because that's more efficient.
32 -- it also means we can give it the same name as the old one without renaming.
33 CREATE INDEX user_directory_room_idx ON user_directory(room_id);
34 CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id);
0 /* Copyright 2017 New Vector Ltd
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 -- this is just embarrassing :|
16 ALTER TABLE users_in_pubic_room RENAME TO users_in_public_rooms;
17
18 -- this is only 300K rows on matrix.org and takes ~3s to generate the index,
19 -- so is hopefully not going to block anyone else for that long...
20 CREATE INDEX users_in_public_rooms_room_idx ON users_in_public_rooms(room_id);
21 CREATE UNIQUE INDEX users_in_public_rooms_user_idx ON users_in_public_rooms(user_id);
22 DROP INDEX users_in_pubic_room_room_idx;
23 DROP INDEX users_in_pubic_room_user_idx;
0 /* Copyright 2018 New Vector Ltd
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT;
0 /* Copyright 2018 New Vector Ltd
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 INSERT into background_updates (update_name, progress_json)
16 VALUES ('event_search_postgres_gin', '{}');
0 /* Copyright 2018 New Vector Ltd
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 -- Temporary staging area for push actions that have been calculated for an
16 -- event, but the event hasn't yet been persisted.
17 -- When the event is persisted the rows are moved over to the
18 -- event_push_actions table.
19 CREATE TABLE event_push_actions_staging (
20 event_id TEXT NOT NULL,
21 user_id TEXT NOT NULL,
22 actions TEXT NOT NULL,
23 notif SMALLINT NOT NULL,
24 highlight SMALLINT NOT NULL
25 );
26
27 CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id);
0 # Copyright 2018 New Vector Ltd
1 #
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at
5 #
6 # http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 from synapse.storage.engines import PostgresEngine
15
16
17 def run_create(cur, database_engine, *args, **kwargs):
18 if isinstance(database_engine, PostgresEngine):
19 # if we already have some state groups, we want to start making new
20 # ones with a higher id.
21 cur.execute("SELECT max(id) FROM state_groups")
22 row = cur.fetchone()
23
24 if row[0] is None:
25 start_val = 1
26 else:
27 start_val = row[0] + 1
28
29 cur.execute(
30 "CREATE SEQUENCE state_group_id_seq START WITH %s",
31 (start_val, ),
32 )
33
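# Illustrative note (not part of the migration): once created, the sequence
# would typically be consumed with e.g. SELECT nextval('state_group_id_seq'),
# so that newly allocated state group ids start above the previous maximum
# computed in run_create above.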
34
35 def run_upgrade(*args, **kwargs):
36 pass
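A hedged sketch of how the sequence created above would be consumed on Postgres (this mirrors the engine-level get_next_state_group_id helper referenced later in this diff; SQLite has no sequences and is handled differently):

    def get_next_state_group_id(txn):
        # Each new state group takes its id from the sequence, so ids keep
        # increasing past the previous maximum computed in run_create above.
        txn.execute("SELECT nextval('state_group_id_seq')")
        return txn.fetchone()[0]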
2424 file TEXT NOT NULL,
2525 UNIQUE(version, file)
2626 );
27
28 -- a list of schema files we have loaded on behalf of dynamic modules
29 CREATE TABLE IF NOT EXISTS applied_module_schemas(
30 module_name TEXT NOT NULL,
31 file TEXT NOT NULL,
32 UNIQUE(module_name, file)
33 );
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 from collections import namedtuple
16 import logging
17 import re
18 import simplejson as json
19
1520 from twisted.internet import defer
1621
1722 from .background_updates import BackgroundUpdateStore
1823 from synapse.api.errors import SynapseError
1924 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
2025
21 import logging
22 import re
23 import ujson as json
24
2526
2627 logger = logging.getLogger(__name__)
28
29 SearchEntry = namedtuple('SearchEntry', [
30 'key', 'value', 'event_id', 'room_id', 'stream_ordering',
31 'origin_server_ts',
32 ])
2733
2834
2935 class SearchStore(BackgroundUpdateStore):
3137 EVENT_SEARCH_UPDATE_NAME = "event_search"
3238 EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
3339 EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
34
35 def __init__(self, hs):
36 super(SearchStore, self).__init__(hs)
40 EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
41
42 def __init__(self, db_conn, hs):
43 super(SearchStore, self).__init__(db_conn, hs)
3744 self.register_background_update_handler(
3845 self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
3946 )
4148 self.EVENT_SEARCH_ORDER_UPDATE_NAME,
4249 self._background_reindex_search_order
4350 )
51
52 # we used to have a background update to turn the GIN index into a
53 # GIST one; we no longer do that (obviously) because we actually want
54 # a GIN index. However, it's possible that some people might still have
55 # the background update queued, so we register a handler to clear the
56 # background update.
57 self.register_noop_background_update(
58 self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME,
59 )
60
4461 self.register_background_update_handler(
45 self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME,
46 self._background_reindex_gist_search
62 self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME,
63 self._background_reindex_gin_search
4764 )
4865
4966 @defer.inlineCallbacks
5067 def _background_reindex_search(self, progress, batch_size):
68 # we work through the events table from highest stream id to lowest
5169 target_min_stream_id = progress["target_min_stream_id_inclusive"]
5270 max_stream_id = progress["max_stream_id_exclusive"]
5371 rows_inserted = progress.get("rows_inserted", 0)
5472
55 INSERT_CLUMP_SIZE = 1000
5673 TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
5774
5875 def reindex_search_txn(txn):
5976 sql = (
60 "SELECT stream_ordering, event_id, room_id, type, content FROM events"
77 "SELECT stream_ordering, event_id, room_id, type, content, "
78 " origin_server_ts FROM events"
6179 " WHERE ? <= stream_ordering AND stream_ordering < ?"
6280 " AND (%s)"
6381 " ORDER BY stream_ordering DESC"
6684
6785 txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
6886
87 # we could stream straight from the results into
88 # store_search_entries_txn with a generator function, but that
89 # would mean having two cursors open on the database at once.
90 # Instead we just build a list of results.
6991 rows = self.cursor_to_dict(txn)
7092 if not rows:
7193 return 0
78100 event_id = row["event_id"]
79101 room_id = row["room_id"]
80102 etype = row["type"]
103 stream_ordering = row["stream_ordering"]
104 origin_server_ts = row["origin_server_ts"]
81105 try:
82106 content = json.loads(row["content"])
83 except:
107 except Exception:
84108 continue
85109
86110 if etype == "m.room.message":
92116 elif etype == "m.room.name":
93117 key = "content.name"
94118 value = content["name"]
119 else:
120 raise Exception("unexpected event type %s" % etype)
95121 except (KeyError, AttributeError):
96122 # If the event is missing a necessary field then
97123 # skip over it.
102128 # then skip over it
103129 continue
104130
105 event_search_rows.append((event_id, room_id, key, value))
106
107 if isinstance(self.database_engine, PostgresEngine):
108 sql = (
109 "INSERT INTO event_search (event_id, room_id, key, vector)"
110 " VALUES (?,?,?,to_tsvector('english', ?))"
111 )
112 elif isinstance(self.database_engine, Sqlite3Engine):
113 sql = (
114 "INSERT INTO event_search (event_id, room_id, key, value)"
115 " VALUES (?,?,?,?)"
116 )
117 else:
118 # This should be unreachable.
119 raise Exception("Unrecognized database engine")
120
121 for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE):
122 clump = event_search_rows[index:index + INSERT_CLUMP_SIZE]
123 txn.executemany(sql, clump)
131 event_search_rows.append(SearchEntry(
132 key=key,
133 value=value,
134 event_id=event_id,
135 room_id=room_id,
136 stream_ordering=stream_ordering,
137 origin_server_ts=origin_server_ts,
138 ))
139
140 self.store_search_entries_txn(txn, event_search_rows)
124141
125142 progress = {
126143 "target_min_stream_id_inclusive": target_min_stream_id,
144161 defer.returnValue(result)
145162
146163 @defer.inlineCallbacks
147 def _background_reindex_gist_search(self, progress, batch_size):
164 def _background_reindex_gin_search(self, progress, batch_size):
165 """This handles old synapses which used GIST indexes, if any;
166 converting them back to be GIN as per the actual schema.
167 """
168
148169 def create_index(conn):
149170 conn.rollback()
171
172 # we have to set autocommit, because postgres refuses to
173 # CREATE INDEX CONCURRENTLY without it.
150174 conn.set_session(autocommit=True)
151 c = conn.cursor()
152
153 c.execute(
154 "CREATE INDEX CONCURRENTLY event_search_fts_idx_gist"
155 " ON event_search USING GIST (vector)"
156 )
157
158 c.execute("DROP INDEX event_search_fts_idx")
159
160 conn.set_session(autocommit=False)
175
176 try:
177 c = conn.cursor()
178
179 # if we skipped the conversion to GIST, we may already/still
180 # have an event_search_fts_idx; unfortunately postgres 9.4
181 # doesn't support CREATE INDEX IF NOT EXISTS so we just catch the
182 # exception and ignore it.
183 import psycopg2
184 try:
185 c.execute(
186 "CREATE INDEX CONCURRENTLY event_search_fts_idx"
187 " ON event_search USING GIN (vector)"
188 )
189 except psycopg2.ProgrammingError as e:
190 logger.warn(
191 "Ignoring error %r when trying to switch from GIST to GIN",
192 e
193 )
194
195 # we should now be able to delete the GIST index.
196 c.execute(
197 "DROP INDEX IF EXISTS event_search_fts_idx_gist"
198 )
199 finally:
200 conn.set_session(autocommit=False)
161201
162202 if isinstance(self.database_engine, PostgresEngine):
163203 yield self.runWithConnection(create_index)
164204
165 yield self._end_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME)
205 yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME)
166206 defer.returnValue(1)
167207
168208 @defer.inlineCallbacks
240280 yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME)
241281
242282 defer.returnValue(num_rows)
283
284 def store_event_search_txn(self, txn, event, key, value):
285 """Add event to the search table
286
287 Args:
288 txn (cursor):
289 event (EventBase):
290 key (str):
291 value (str):
292 """
293 self.store_search_entries_txn(
294 txn,
295 (SearchEntry(
296 key=key,
297 value=value,
298 event_id=event.event_id,
299 room_id=event.room_id,
300 stream_ordering=event.internal_metadata.stream_ordering,
301 origin_server_ts=event.origin_server_ts,
302 ),),
303 )
304
305 def store_search_entries_txn(self, txn, entries):
306 """Add entries to the search table
307
308 Args:
309 txn (cursor):
310 entries (iterable[SearchEntry]):
311 entries to be added to the table
312 """
313 if isinstance(self.database_engine, PostgresEngine):
314 sql = (
315 "INSERT INTO event_search"
316 " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
317 " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
318 )
319
320 args = ((
321 entry.event_id, entry.room_id, entry.key, entry.value,
322 entry.stream_ordering, entry.origin_server_ts,
323 ) for entry in entries)
324
325 # inserts to a GIN index are normally batched up into a pending
326 # list, and then all committed together once the list gets to a
327 # certain size. The trouble with that is that postgres (pre-9.5)
328 # uses work_mem to determine the length of the list, and work_mem
329 # is typically very large.
330 #
331 # We therefore reduce work_mem while we do the insert.
332 #
333 # (postgres 9.5 uses the separate gin_pending_list_limit setting,
334 # so doesn't suffer the same problem, but changing work_mem will
335 # be harmless)
336 #
337 # Note that we don't need to worry about restoring it on
338 # exception, because exceptions will cause the transaction to be
339 # rolled back, including the effects of the SET command.
340 #
341 # Also: we use SET rather than SET LOCAL because there's lots of
342 # other stuff going on in this transaction, which we want to run with the
343 # normal work_mem setting.
344
345 txn.execute("SET work_mem='256kB'")
346 txn.executemany(sql, args)
347 txn.execute("RESET work_mem")
348
349 elif isinstance(self.database_engine, Sqlite3Engine):
350 sql = (
351 "INSERT INTO event_search (event_id, room_id, key, value)"
352 " VALUES (?,?,?,?)"
353 )
354 args = ((
355 entry.event_id, entry.room_id, entry.key, entry.value,
356 ) for entry in entries)
357
358 txn.executemany(sql, args)
359 else:
360 # This should be unreachable.
361 raise Exception("Unrecognized database engine")
243362
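As a usage illustration only (the surrounding persistence code and variable names are assumed, not shown in this hunk), a plain text message would typically be indexed through the single-entry helper above:

    # inside the event-persistence transaction, roughly:
    if event.type == "m.room.message" and "body" in event.content:
        self.store_event_search_txn(
            txn, event, "content.body", event.content["body"],
        )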
244363 @defer.inlineCallbacks
245364 def search_msgs(self, room_ids, search_term, keys):
406525 origin_server_ts, stream = pagination_token.split(",")
407526 origin_server_ts = int(origin_server_ts)
408527 stream = int(stream)
409 except:
528 except Exception:
410529 raise SynapseError(400, "Invalid pagination token")
411530
412531 clauses.append(
2121 from synapse.util.caches.descriptors import cached, cachedList
2222
2323
24 class SignatureStore(SQLBaseStore):
25 """Persistence for event signatures and hashes"""
26
24 class SignatureWorkerStore(SQLBaseStore):
2725 @cached()
2826 def get_event_reference_hash(self, event_id):
29 return self._get_event_reference_hashes_txn(event_id)
27 # This is a dummy function to allow get_event_reference_hashes
28 # to use its cache
29 raise NotImplementedError()
3030
3131 @cachedList(cached_method_name="get_event_reference_hash",
3232 list_name="event_ids", num_args=1)
7373 txn.execute(query, (event_id, ))
7474 return {k: v for k, v in txn}
7575
76
77 class SignatureStore(SignatureWorkerStore):
78 """Persistence for event signatures and hashes"""
79
7680 def _store_event_reference_hashes_txn(self, txn, events):
7781 """Store a hash for a PDU
7882 Args:
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 from collections import namedtuple
16 import logging
17
18 from twisted.internet import defer
19
20 from synapse.storage.background_updates import BackgroundUpdateStore
21 from synapse.storage.engines import PostgresEngine
22 from synapse.util.caches import intern_string, CACHE_SIZE_FACTOR
23 from synapse.util.caches.descriptors import cached, cachedList
24 from synapse.util.caches.dictionary_cache import DictionaryCache
25 from synapse.util.stringutils import to_ascii
1526 from ._base import SQLBaseStore
16 from synapse.util.caches.descriptors import cached, cachedList
17 from synapse.util.caches import intern_string
18 from synapse.util.stringutils import to_ascii
19 from synapse.storage.engines import PostgresEngine
20
21 from twisted.internet import defer
22 from collections import namedtuple
23
24 import logging
2527
2628 logger = logging.getLogger(__name__)
2729
3941 return len(self.delta_ids) if self.delta_ids else 0
4042
4143
42 class StateStore(SQLBaseStore):
43 """ Keeps track of the state at a given event.
44
45 This is done by the concept of `state groups`. Every event is assigned
46 a state group (identified by an arbitrary string), which references a
47 collection of state events. The current state of an event is then the
48 collection of state events referenced by the event's state group.
49
50 Hence, every change in the current state causes a new state group to be
51 generated. However, if no change happens (e.g., if we get a message event
52 with only one parent), it inherits the state group from its parent.
53
54 There are three tables:
55 * `state_groups`: Stores group name, first event within the group and
56 room id.
57 * `event_to_state_groups`: Maps events to state groups.
58 * `state_groups_state`: Maps state group to state events.
44 class StateGroupWorkerStore(SQLBaseStore):
45 """The parts of StateGroupStore that can be called from workers.
5946 """
6047
6148 STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
6249 STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
6350 CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
6451
65 def __init__(self, hs):
66 super(StateStore, self).__init__(hs)
67 self.register_background_update_handler(
68 self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
69 self._background_deduplicate_state,
70 )
71 self.register_background_update_handler(
72 self.STATE_GROUP_INDEX_UPDATE_NAME,
73 self._background_index_state,
74 )
75 self.register_background_index_update(
76 self.CURRENT_STATE_INDEX_UPDATE_NAME,
77 index_name="current_state_events_member_index",
78 table="current_state_events",
79 columns=["state_key"],
80 where_clause="type='m.room.member'",
52 def __init__(self, db_conn, hs):
53 super(StateGroupWorkerStore, self).__init__(db_conn, hs)
54
55 self._state_group_cache = DictionaryCache(
56 "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
8157 )
8258
8359 @cached(max_entries=100000, iterable=True)
163139 defer.returnValue(group_to_state)
164140
165141 @defer.inlineCallbacks
142 def get_state_ids_for_group(self, state_group):
143 """Get the state IDs for the given state group
144
145 Args:
146 state_group (int)
147
148 Returns:
149 Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
150 """
151 group_to_state = yield self._get_state_for_groups((state_group,))
152
153 defer.returnValue(group_to_state[state_group])
154
155 @defer.inlineCallbacks
166156 def get_state_groups(self, room_id, event_ids):
167157 """ Get the state groups for the given list of event_ids
168158
188178 ]
189179 for group, event_id_map in group_to_ids.iteritems()
190180 })
191
192 def _have_persisted_state_group_txn(self, txn, state_group):
193 txn.execute(
194 "SELECT count(*) FROM state_groups WHERE id = ?",
195 (state_group,)
196 )
197 row = txn.fetchone()
198 return row and row[0]
199
200 def _store_mult_state_groups_txn(self, txn, events_and_contexts):
201 state_groups = {}
202 for event, context in events_and_contexts:
203 if event.internal_metadata.is_outlier():
204 continue
205
206 if context.current_state_ids is None:
207 # AFAIK, this can never happen
208 logger.error(
209 "Non-outlier event %s had current_state_ids==None",
210 event.event_id)
211 continue
212
213 # if the event was rejected, just give it the same state as its
214 # predecessor.
215 if context.rejected:
216 state_groups[event.event_id] = context.prev_group
217 continue
218
219 state_groups[event.event_id] = context.state_group
220
221 if self._have_persisted_state_group_txn(txn, context.state_group):
222 continue
223
224 self._simple_insert_txn(
225 txn,
226 table="state_groups",
227 values={
228 "id": context.state_group,
229 "room_id": event.room_id,
230 "event_id": event.event_id,
231 },
232 )
233
234 # We persist as a delta if we can, while also ensuring the chain
235 # of deltas isn't too long, as otherwise read performance degrades.
236 if context.prev_group:
237 is_in_db = self._simple_select_one_onecol_txn(
238 txn,
239 table="state_groups",
240 keyvalues={"id": context.prev_group},
241 retcol="id",
242 allow_none=True,
243 )
244 if not is_in_db:
245 raise Exception(
246 "Trying to persist state with unpersisted prev_group: %r"
247 % (context.prev_group,)
248 )
249
250 potential_hops = self._count_state_group_hops_txn(
251 txn, context.prev_group
252 )
253 if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
254 self._simple_insert_txn(
255 txn,
256 table="state_group_edges",
257 values={
258 "state_group": context.state_group,
259 "prev_state_group": context.prev_group,
260 },
261 )
262
263 self._simple_insert_many_txn(
264 txn,
265 table="state_groups_state",
266 values=[
267 {
268 "state_group": context.state_group,
269 "room_id": event.room_id,
270 "type": key[0],
271 "state_key": key[1],
272 "event_id": state_id,
273 }
274 for key, state_id in context.delta_ids.iteritems()
275 ],
276 )
277 else:
278 self._simple_insert_many_txn(
279 txn,
280 table="state_groups_state",
281 values=[
282 {
283 "state_group": context.state_group,
284 "room_id": event.room_id,
285 "type": key[0],
286 "state_key": key[1],
287 "event_id": state_id,
288 }
289 for key, state_id in context.current_state_ids.iteritems()
290 ],
291 )
292
293 # Prefill the state group cache with this group.
294 # It's fine to use the sequence like this as the state group map
295 # is immutable. (If the map wasn't immutable then this prefill could
296 # race with another update)
297 txn.call_after(
298 self._state_group_cache.update,
299 self._state_group_cache.sequence,
300 key=context.state_group,
301 value=dict(context.current_state_ids),
302 full=True,
303 )
304
305 self._simple_insert_many_txn(
306 txn,
307 table="event_to_state_groups",
308 values=[
309 {
310 "state_group": state_group_id,
311 "event_id": event_id,
312 }
313 for event_id, state_group_id in state_groups.iteritems()
314 ],
315 )
316
317 for event_id, state_group_id in state_groups.iteritems():
318 txn.call_after(
319 self._get_state_group_for_event.prefill,
320 (event_id,), state_group_id
321 )
322
323 def _count_state_group_hops_txn(self, txn, state_group):
324 """Given a state group, count how many hops there are in the tree.
325
326 This is used to ensure the delta chains don't get too long.
327 """
328 if isinstance(self.database_engine, PostgresEngine):
329 sql = ("""
330 WITH RECURSIVE state(state_group) AS (
331 VALUES(?::bigint)
332 UNION ALL
333 SELECT prev_state_group FROM state_group_edges e, state s
334 WHERE s.state_group = e.state_group
335 )
336 SELECT count(*) FROM state;
337 """)
338
339 txn.execute(sql, (state_group,))
340 row = txn.fetchone()
341 if row and row[0]:
342 return row[0]
343 else:
344 return 0
345 else:
346 # We don't use WITH RECURSIVE on sqlite3 as there are distributions
347 # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
348 next_group = state_group
349 count = 0
350
351 while next_group:
352 next_group = self._simple_select_one_onecol_txn(
353 txn,
354 table="state_group_edges",
355 keyvalues={"state_group": next_group},
356 retcol="prev_state_group",
357 allow_none=True,
358 )
359 if next_group:
360 count += 1
361
362 return count
363181
364182 @defer.inlineCallbacks
365183 def _get_state_groups_from_groups(self, groups, types):
421239 (
422240 "AND type = ? AND state_key = ?",
423241 (etype, state_key)
242 ) if state_key is not None else (
243 "AND type = ?",
244 (etype,)
424245 )
425246 for etype, state_key in types
426247 ]
440261 key = (typ, state_key)
441262 results[group][key] = event_id
442263 else:
264 where_args = []
265 where_clauses = []
266 wildcard_types = False
443267 if types is not None:
444 where_clause = "AND (%s)" % (
445 " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
446 )
268 for typ in types:
269 if typ[1] is None:
270 where_clauses.append("(type = ?)")
271 where_args.append(typ[0])
272 wildcard_types = True
273 else:
274 where_clauses.append("(type = ? AND state_key = ?)")
275 where_args.extend([typ[0], typ[1]])
276 where_clause = "AND (%s)" % (" OR ".join(where_clauses))
447277 else:
448278 where_clause = ""
449279
460290 # after we finish deduping state, which requires this func)
461291 args = [next_group]
462292 if types:
463 args.extend(i for typ in types for i in typ)
293 args.extend(where_args)
464294
465295 txn.execute(
466296 "SELECT type, state_key, event_id FROM state_groups_state"
473303 if (typ, state_key) not in results[group]
474304 )
475305
476 # If the lengths match then we must have all the types,
477 # so no need to go walk further down the tree.
478 if types is not None and len(results[group]) == len(types):
306 # If the number of entries in the (type,state_key)->event_id dict
307 # matches the number of (type,state_keys) types we were searching
308 # for, then we must have found them all, so no need to go walk
309 # further down the tree... UNLESS our types filter contained
310 # wildcards (i.e. Nones) in which case we have to do an exhaustive
311 # search
312 if (
313 types is not None and
314 not wildcard_types and
315 len(results[group]) == len(types)
316 ):
479317 break
480318
481319 next_group = self._simple_select_one_onecol_txn(
741579
742580 defer.returnValue(results)
743581
744 def get_next_state_group(self):
745 return self._state_groups_id_gen.get_next()
582 def store_state_group(self, event_id, room_id, prev_group, delta_ids,
583 current_state_ids):
584 """Store a new set of state, returning a newly assigned state group.
585
586 Args:
587 event_id (str): The event ID for which the state was calculated
588 room_id (str)
589 prev_group (int|None): A previous state group for the room, optional.
590 delta_ids (dict|None): The delta between state at `prev_group` and
591 `current_state_ids`, if `prev_group` was given. Same format as
592 `current_state_ids`.
593 current_state_ids (dict): The state to store. Map of (type, state_key)
594 to event_id.
595
596 Returns:
597 Deferred[int]: The state group ID
598 """
599 def _store_state_group_txn(txn):
600 if current_state_ids is None:
601 # AFAIK, this can never happen
602 raise Exception("current_state_ids cannot be None")
603
604 state_group = self.database_engine.get_next_state_group_id(txn)
605
606 self._simple_insert_txn(
607 txn,
608 table="state_groups",
609 values={
610 "id": state_group,
611 "room_id": room_id,
612 "event_id": event_id,
613 },
614 )
615
616 # We persist as a delta if we can, while also ensuring the chain
617 # of deltas isn't too long, as otherwise read performance degrades.
618 if prev_group:
619 is_in_db = self._simple_select_one_onecol_txn(
620 txn,
621 table="state_groups",
622 keyvalues={"id": prev_group},
623 retcol="id",
624 allow_none=True,
625 )
626 if not is_in_db:
627 raise Exception(
628 "Trying to persist state with unpersisted prev_group: %r"
629 % (prev_group,)
630 )
631
632 potential_hops = self._count_state_group_hops_txn(
633 txn, prev_group
634 )
635 if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
636 self._simple_insert_txn(
637 txn,
638 table="state_group_edges",
639 values={
640 "state_group": state_group,
641 "prev_state_group": prev_group,
642 },
643 )
644
645 self._simple_insert_many_txn(
646 txn,
647 table="state_groups_state",
648 values=[
649 {
650 "state_group": state_group,
651 "room_id": room_id,
652 "type": key[0],
653 "state_key": key[1],
654 "event_id": state_id,
655 }
656 for key, state_id in delta_ids.iteritems()
657 ],
658 )
659 else:
660 self._simple_insert_many_txn(
661 txn,
662 table="state_groups_state",
663 values=[
664 {
665 "state_group": state_group,
666 "room_id": room_id,
667 "type": key[0],
668 "state_key": key[1],
669 "event_id": state_id,
670 }
671 for key, state_id in current_state_ids.iteritems()
672 ],
673 )
674
675 # Prefill the state group cache with this group.
676 # It's fine to use the sequence like this as the state group map
677 # is immutable. (If the map wasn't immutable then this prefill could
678 # race with another update)
679 txn.call_after(
680 self._state_group_cache.update,
681 self._state_group_cache.sequence,
682 key=state_group,
683 value=dict(current_state_ids),
684 full=True,
685 )
686
687 return state_group
688
689 return self.runInteraction("store_state_group", _store_state_group_txn)
690
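For illustration only (the calling context below is assumed, not part of this hunk), the event-persistence path can now mint a state group up front instead of taking ids from the removed get_next_state_group:

    state_group = yield self.store_state_group(
        event_id=event.event_id,
        room_id=event.room_id,
        prev_group=context.prev_group,        # may be None
        delta_ids=context.delta_ids,          # delta vs prev_group, if given
        current_state_ids=current_state_ids,  # {(type, state_key): event_id}
    )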
691 def _count_state_group_hops_txn(self, txn, state_group):
692 """Given a state group, count how many hops there are in the tree.
693
694 This is used to ensure the delta chains don't get too long.
695 """
696 if isinstance(self.database_engine, PostgresEngine):
697 sql = ("""
698 WITH RECURSIVE state(state_group) AS (
699 VALUES(?::bigint)
700 UNION ALL
701 SELECT prev_state_group FROM state_group_edges e, state s
702 WHERE s.state_group = e.state_group
703 )
704 SELECT count(*) FROM state;
705 """)
706
707 txn.execute(sql, (state_group,))
708 row = txn.fetchone()
709 if row and row[0]:
710 return row[0]
711 else:
712 return 0
713 else:
714 # We don't use WITH RECURSIVE on sqlite3 as there are distributions
715 # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
716 next_group = state_group
717 count = 0
718
719 while next_group:
720 next_group = self._simple_select_one_onecol_txn(
721 txn,
722 table="state_group_edges",
723 keyvalues={"state_group": next_group},
724 retcol="prev_state_group",
725 allow_none=True,
726 )
727 if next_group:
728 count += 1
729
730 return count
731
732
733 class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
734 """ Keeps track of the state at a given event.
735
736 This is done by the concept of `state groups`. Every event is assigned
737 a state group (identified by an arbitrary string), which references a
738 collection of state events. The current state of an event is then the
739 collection of state events referenced by the event's state group.
740
741 Hence, every change in the current state causes a new state group to be
742 generated. However, if no change happens (e.g., if we get a message event
743 with only one parent), it inherits the state group from its parent.
744
745 There are three tables:
746 * `state_groups`: Stores group name, first event within the group and
747 room id.
748 * `event_to_state_groups`: Maps events to state groups.
749 * `state_groups_state`: Maps state group to state events.
750 """
751
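A self-contained toy sketch of the data model described in the docstring above (invented rows, not real data), showing how a delta chain stored via state_group_edges and state_groups_state resolves to the full state:

    state_group_edges = {102: 101}  # state_group -> prev_state_group
    state_groups_state = {
        101: {("m.room.create", ""): "$create", ("m.room.member", "@a:x"): "$a"},
        102: {("m.room.member", "@b:x"): "$b"},  # stored as a delta on 101
    }

    def resolve_state(group):
        chain = []
        while group is not None:
            chain.append(group)
            group = state_group_edges.get(group)
        state = {}
        for g in reversed(chain):  # apply oldest first; newer entries win
            state.update(state_groups_state.get(g, {}))
        return state

    # resolve_state(102) yields the create event plus both memberships.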
752 STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
753 STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
754 CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
755
756 def __init__(self, db_conn, hs):
757 super(StateStore, self).__init__(db_conn, hs)
758 self.register_background_update_handler(
759 self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
760 self._background_deduplicate_state,
761 )
762 self.register_background_update_handler(
763 self.STATE_GROUP_INDEX_UPDATE_NAME,
764 self._background_index_state,
765 )
766 self.register_background_index_update(
767 self.CURRENT_STATE_INDEX_UPDATE_NAME,
768 index_name="current_state_events_member_index",
769 table="current_state_events",
770 columns=["state_key"],
771 where_clause="type='m.room.member'",
772 )
773
774 def _store_event_state_mappings_txn(self, txn, events_and_contexts):
775 state_groups = {}
776 for event, context in events_and_contexts:
777 if event.internal_metadata.is_outlier():
778 continue
779
780 # if the event was rejected, just give it the same state as its
781 # predecessor.
782 if context.rejected:
783 state_groups[event.event_id] = context.prev_group
784 continue
785
786 state_groups[event.event_id] = context.state_group
787
788 self._simple_insert_many_txn(
789 txn,
790 table="event_to_state_groups",
791 values=[
792 {
793 "state_group": state_group_id,
794 "event_id": event_id,
795 }
796 for event_id, state_group_id in state_groups.iteritems()
797 ],
798 )
799
800 for event_id, state_group_id in state_groups.iteritems():
801 txn.call_after(
802 self._get_state_group_for_event.prefill,
803 (event_id,), state_group_id
804 )
746805
747806 @defer.inlineCallbacks
748807 def _background_deduplicate_state(self, progress, batch_size):
3434
3535 from twisted.internet import defer
3636
37 from ._base import SQLBaseStore
37 from synapse.storage._base import SQLBaseStore
38 from synapse.storage.events import EventsWorkerStore
39
3840 from synapse.util.caches.descriptors import cached
39 from synapse.api.constants import EventTypes
4041 from synapse.types import RoomStreamToken
41 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
42 from synapse.util.caches.stream_change_cache import StreamChangeCache
43 from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
4244 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
4345
46 import abc
4447 import logging
4548
4649
142145 return " AND ".join(clauses), args
143146
144147
145 class StreamStore(SQLBaseStore):
146 @defer.inlineCallbacks
147 def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
148 # NB this lives here instead of appservice.py so we can reuse the
149 # 'private' StreamToken class in this file.
150 if limit:
151 limit = max(limit, MAX_STREAM_SIZE)
152 else:
153 limit = MAX_STREAM_SIZE
154
155 # From and to keys should be integers from ordering.
156 from_id = RoomStreamToken.parse_stream_token(from_key)
157 to_id = RoomStreamToken.parse_stream_token(to_key)
158
159 if from_key == to_key:
160 defer.returnValue(([], to_key))
161 return
162
163 # select all the events between from/to with a sensible limit
164 sql = (
165 "SELECT e.event_id, e.room_id, e.type, s.state_key, "
166 "e.stream_ordering FROM events AS e "
167 "LEFT JOIN state_events as s ON "
168 "e.event_id = s.event_id "
169 "WHERE e.stream_ordering > ? AND e.stream_ordering <= ? "
170 "ORDER BY stream_ordering ASC LIMIT %(limit)d "
171 ) % {
172 "limit": limit
173 }
174
175 def f(txn):
176 # pull out all the events between the tokens
177 txn.execute(sql, (from_id.stream, to_id.stream,))
178 rows = self.cursor_to_dict(txn)
179
180 # Logic:
181 # - We want ALL events which match the AS room_id regex
182 # - We want ALL events which match the rooms represented by the AS
183 # room_alias regex
184 # - We want ALL events for rooms that AS users have joined.
185 # This is currently supported via get_app_service_rooms (which is
186 # used for the Notifier listener rooms). We can't reasonably make a
187 # SQL query for these room IDs, so we'll pull all the events between
188 # from/to and filter in python.
189 rooms_for_as = self._get_app_service_rooms_txn(txn, service)
190 room_ids_for_as = [r.room_id for r in rooms_for_as]
191
192 def app_service_interested(row):
193 if row["room_id"] in room_ids_for_as:
194 return True
195
196 if row["type"] == EventTypes.Member:
197 if service.is_interested_in_user(row.get("state_key")):
198 return True
199 return False
200
201 return [r for r in rows if app_service_interested(r)]
202
203 rows = yield self.runInteraction("get_appservice_room_stream", f)
204
205 ret = yield self._get_events(
206 [r["event_id"] for r in rows],
207 get_prev_content=True
208 )
209
210 self._set_before_and_after(ret, rows, topo_order=from_id is None)
211
212 if rows:
213 key = "s%d" % max(r["stream_ordering"] for r in rows)
214 else:
215 # Assume we didn't get anything because there was nothing to
216 # get.
217 key = to_key
218
219 defer.returnValue((ret, key))
148 class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
149 """This is an abstract base class where subclasses must implement
150 `get_room_max_stream_ordering` and `get_room_min_stream_ordering`
151 which can be called in the initializer.
152 """
153
154 __metaclass__ = abc.ABCMeta
155
156 def __init__(self, db_conn, hs):
157 super(StreamWorkerStore, self).__init__(db_conn, hs)
158
159 events_max = self.get_room_max_stream_ordering()
160 event_cache_prefill, min_event_val = self._get_cache_dict(
161 db_conn, "events",
162 entity_column="room_id",
163 stream_column="stream_ordering",
164 max_value=events_max,
165 )
166 self._events_stream_cache = StreamChangeCache(
167 "EventsRoomStreamChangeCache", min_event_val,
168 prefilled_cache=event_cache_prefill,
169 )
170 self._membership_stream_cache = StreamChangeCache(
171 "MembershipStreamChangeCache", events_max,
172 )
173
174 self._stream_order_on_start = self.get_room_max_stream_ordering()
175
176 @abc.abstractmethod
177 def get_room_max_stream_ordering(self):
178 raise NotImplementedError()
179
180 @abc.abstractmethod
181 def get_room_min_stream_ordering(self):
182 raise NotImplementedError()
220183
221184 @defer.inlineCallbacks
222185 def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0,
233196 results = {}
234197 room_ids = list(room_ids)
235198 for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
236 res = yield preserve_context_over_deferred(defer.gatherResults([
199 res = yield make_deferred_yieldable(defer.gatherResults([
237200 preserve_fn(self.get_room_events_stream_for_room)(
238201 room_id, from_key, to_key, limit, order=order,
239202 )
380343 defer.returnValue(ret)
381344
382345 @defer.inlineCallbacks
383 def paginate_room_events(self, room_id, from_key, to_key=None,
384 direction='b', limit=-1, event_filter=None):
385 # Tokens really represent positions between elements, but we use
386 # the convention of pointing to the event before the gap. Hence
387 # we have a bit of asymmetry when it comes to equalities.
388 args = [False, room_id]
389 if direction == 'b':
390 order = "DESC"
391 bounds = upper_bound(
392 RoomStreamToken.parse(from_key), self.database_engine
393 )
394 if to_key:
395 bounds = "%s AND %s" % (bounds, lower_bound(
396 RoomStreamToken.parse(to_key), self.database_engine
397 ))
398 else:
399 order = "ASC"
400 bounds = lower_bound(
401 RoomStreamToken.parse(from_key), self.database_engine
402 )
403 if to_key:
404 bounds = "%s AND %s" % (bounds, upper_bound(
405 RoomStreamToken.parse(to_key), self.database_engine
406 ))
407
408 filter_clause, filter_args = filter_to_clause(event_filter)
409
410 if filter_clause:
411 bounds += " AND " + filter_clause
412 args.extend(filter_args)
413
414 if int(limit) > 0:
415 args.append(int(limit))
416 limit_str = " LIMIT ?"
417 else:
418 limit_str = ""
419
420 sql = (
421 "SELECT * FROM events"
422 " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
423 " ORDER BY topological_ordering %(order)s,"
424 " stream_ordering %(order)s %(limit)s"
425 ) % {
426 "bounds": bounds,
427 "order": order,
428 "limit": limit_str
429 }
430
431 def f(txn):
432 txn.execute(sql, args)
433
434 rows = self.cursor_to_dict(txn)
435
436 if rows:
437 topo = rows[-1]["topological_ordering"]
438 toke = rows[-1]["stream_ordering"]
439 if direction == 'b':
440 # Tokens are positions between events.
441 # This token points *after* the last event in the chunk.
442 # We need it to point to the event before it in the chunk
443 # when we are going backwards so we subtract one from the
444 # stream part.
445 toke -= 1
446 next_token = str(RoomStreamToken(topo, toke))
447 else:
448 # TODO (erikj): We should work out what to do here instead.
449 next_token = to_key if to_key else from_key
450
451 return rows, next_token,
452
453 rows, token = yield self.runInteraction("paginate_room_events", f)
454
455 events = yield self._get_events(
456 [r["event_id"] for r in rows],
457 get_prev_content=True
458 )
459
460 self._set_before_and_after(events, rows)
461
462 defer.returnValue((events, token))
463
464 @defer.inlineCallbacks
465346 def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None):
466347 rows, token = yield self.get_recent_event_ids_for_room(
467348 room_id, limit, end_token, from_token
533414 "get_recent_events_for_room", get_recent_events_for_room_txn
534415 )
535416
417 def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
418 """Gets details of the first event in a room at or after a stream ordering
419
420 Args:
421 room_id (str):
422 stream_ordering (int):
423
424 Returns:
425 Deferred[(int, int, str)]:
426 (stream ordering, topological ordering, event_id)
427 """
428 def _f(txn):
429 sql = (
430 "SELECT stream_ordering, topological_ordering, event_id"
431 " FROM events"
432 " WHERE room_id = ? AND stream_ordering >= ?"
433 " AND NOT outlier"
434 " ORDER BY stream_ordering"
435 " LIMIT 1"
436 )
437 txn.execute(sql, (room_id, stream_ordering, ))
438 return txn.fetchone()
439
440 return self.runInteraction(
441 "get_room_event_after_stream_ordering", _f,
442 )
443
536444 @defer.inlineCallbacks
537445 def get_room_events_max_id(self, room_id=None):
538446 """Returns the current token for rooms stream.
541449 `room_id` causes it to return the current room specific topological
542450 token.
543451 """
544 token = yield self._stream_id_gen.get_current_token()
452 token = yield self.get_room_max_stream_ordering()
545453 if room_id is None:
546454 defer.returnValue("s%d" % (token,))
547455 else:
550458 room_id,
551459 )
552460 defer.returnValue("t%d-%d" % (topo, token))
553
554 def get_room_max_stream_ordering(self):
555 return self._stream_id_gen.get_current_token()
556
557 def get_room_min_stream_ordering(self):
558 return self._backfill_id_gen.get_current_token()
559461
560462 def get_stream_token_for_event(self, event_id):
561463 """The stream token for an event
831733
832734 def has_room_changed_since(self, room_id, stream_id):
833735 return self._events_stream_cache.has_entity_changed(room_id, stream_id)
736
737
738 class StreamStore(StreamWorkerStore):
739 def get_room_max_stream_ordering(self):
740 return self._stream_id_gen.get_current_token()
741
742 def get_room_min_stream_ordering(self):
743 return self._backfill_id_gen.get_current_token()
744
745 @defer.inlineCallbacks
746 def paginate_room_events(self, room_id, from_key, to_key=None,
747 direction='b', limit=-1, event_filter=None):
748 # Tokens really represent positions between elements, but we use
749 # the convention of pointing to the event before the gap. Hence
750 # we have a bit of asymmetry when it comes to equalities.
751 args = [False, room_id]
752 if direction == 'b':
753 order = "DESC"
754 bounds = upper_bound(
755 RoomStreamToken.parse(from_key), self.database_engine
756 )
757 if to_key:
758 bounds = "%s AND %s" % (bounds, lower_bound(
759 RoomStreamToken.parse(to_key), self.database_engine
760 ))
761 else:
762 order = "ASC"
763 bounds = lower_bound(
764 RoomStreamToken.parse(from_key), self.database_engine
765 )
766 if to_key:
767 bounds = "%s AND %s" % (bounds, upper_bound(
768 RoomStreamToken.parse(to_key), self.database_engine
769 ))
770
771 filter_clause, filter_args = filter_to_clause(event_filter)
772
773 if filter_clause:
774 bounds += " AND " + filter_clause
775 args.extend(filter_args)
776
777 if int(limit) > 0:
778 args.append(int(limit))
779 limit_str = " LIMIT ?"
780 else:
781 limit_str = ""
782
783 sql = (
784 "SELECT * FROM events"
785 " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
786 " ORDER BY topological_ordering %(order)s,"
787 " stream_ordering %(order)s %(limit)s"
788 ) % {
789 "bounds": bounds,
790 "order": order,
791 "limit": limit_str
792 }
793
794 def f(txn):
795 txn.execute(sql, args)
796
797 rows = self.cursor_to_dict(txn)
798
799 if rows:
800 topo = rows[-1]["topological_ordering"]
801 toke = rows[-1]["stream_ordering"]
802 if direction == 'b':
803 # Tokens are positions between events.
804 # This token points *after* the last event in the chunk.
805 # We need it to point to the event before it in the chunk
806 # when we are going backwards so we subtract one from the
807 # stream part.
808 toke -= 1
809 next_token = str(RoomStreamToken(topo, toke))
810 else:
811 # TODO (erikj): We should work out what to do here instead.
812 next_token = to_key if to_key else from_key
813
814 return rows, next_token,
815
816 rows, token = yield self.runInteraction("paginate_room_events", f)
817
818 events = yield self._get_events(
819 [r["event_id"] for r in rows],
820 get_prev_content=True
821 )
822
823 self._set_before_and_after(events, rows)
824
825 defer.returnValue((events, token))
00 # -*- coding: utf-8 -*-
11 # Copyright 2014-2016 OpenMarket Ltd
2 # Copyright 2018 New Vector Ltd
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1213 # See the License for the specific language governing permissions and
1314 # limitations under the License.
1415
15 from ._base import SQLBaseStore
16 from synapse.storage.account_data import AccountDataWorkerStore
17
1618 from synapse.util.caches.descriptors import cached
1719 from twisted.internet import defer
1820
19 import ujson as json
21 import simplejson as json
2022 import logging
2123
2224 logger = logging.getLogger(__name__)
2325
2426
25 class TagsStore(SQLBaseStore):
26 def get_max_account_data_stream_id(self):
27 """Get the current max stream id for the private user data stream
28
29 Returns:
30 A deferred int.
31 """
32 return self._account_data_id_gen.get_current_token()
33
27 class TagsWorkerStore(AccountDataWorkerStore):
3428 @cached()
3529 def get_tags_for_user(self, user_id):
3630 """Get all the tags for a user.
169163 row["tag"]: json.loads(row["content"]) for row in rows
170164 })
171165
166
167 class TagsStore(TagsWorkerStore):
172168 @defer.inlineCallbacks
173169 def add_tag_to_room(self, user_id, room_id, tag, content):
174170 """Add a tag to a room for a user.
2222 from collections import namedtuple
2323
2424 import logging
25 import ujson as json
25 import simplejson as json
2626
2727 logger = logging.getLogger(__name__)
2828
4545 """A collection of queries for handling PDUs.
4646 """
4747
48 def __init__(self, hs):
49 super(TransactionStore, self).__init__(hs)
48 def __init__(self, db_conn, hs):
49 super(TransactionStore, self).__init__(db_conn, hs)
5050
5151 self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
5252
6262 user_ids (list(str)): Users to add
6363 """
6464 yield self._simple_insert_many(
65 table="users_in_pubic_room",
65 table="users_in_public_rooms",
6666 values=[
6767 {
6868 "user_id": user_id,
163163 )
164164
165165 if isinstance(self.database_engine, PostgresEngine):
166 # We weight the loclpart most highly, then display name and finally
166 # We weight the localpart most highly, then display name and finally
167167 # server name
168168 if new_entry:
169169 sql = """
218218 @defer.inlineCallbacks
219219 def update_user_in_public_user_list(self, user_id, room_id):
220220 yield self._simple_update_one(
221 table="users_in_pubic_room",
221 table="users_in_public_rooms",
222222 keyvalues={"user_id": user_id},
223223 updatevalues={"room_id": room_id},
224224 desc="update_user_in_public_user_list",
239239 )
240240 self._simple_delete_txn(
241241 txn,
242 table="users_in_pubic_room",
242 table="users_in_public_rooms",
243243 keyvalues={"user_id": user_id},
244244 )
245245 txn.call_after(
255255 @defer.inlineCallbacks
256256 def remove_from_user_in_public_room(self, user_id):
257257 yield self._simple_delete(
258 table="users_in_pubic_room",
258 table="users_in_public_rooms",
259259 keyvalues={"user_id": user_id},
260260 desc="remove_from_user_in_public_room",
261261 )
266266 in the given room_id
267267 """
268268 return self._simple_select_onecol(
269 table="users_in_pubic_room",
269 table="users_in_public_rooms",
270270 keyvalues={"room_id": room_id},
271271 retcol="user_id",
272272 desc="get_users_in_public_due_to_room",
285285 )
286286
287287 user_ids_pub = yield self._simple_select_onecol(
288 table="users_in_pubic_room",
288 table="users_in_public_rooms",
289289 keyvalues={"room_id": room_id},
290290 retcol="user_id",
291291 desc="get_users_in_dir_due_to_room",
315315 """
316316 rows = yield self._execute("get_all_rooms", None, sql)
317317 defer.returnValue([room_id for room_id, in rows])
318
319 @defer.inlineCallbacks
320 def get_all_local_users(self):
321 """Get all local users
322 """
323 sql = """
324 SELECT name FROM users
325 """
326 rows = yield self._execute("get_all_local_users", None, sql)
327 defer.returnValue([name for name, in rows])
318328
319329 def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
320330 """Insert entries into the users_who_share_rooms table. The first
513523 def _delete_all_from_user_dir_txn(txn):
514524 txn.execute("DELETE FROM user_directory")
515525 txn.execute("DELETE FROM user_directory_search")
516 txn.execute("DELETE FROM users_in_pubic_room")
526 txn.execute("DELETE FROM users_in_public_rooms")
517527 txn.execute("DELETE FROM users_who_share_rooms")
518528 txn.call_after(self.get_user_in_directory.invalidate_all)
519529 txn.call_after(self.get_user_in_public_room.invalidate_all)
536546 @cached()
537547 def get_user_in_public_room(self, user_id):
538548 return self._simple_select_one(
539 table="users_in_pubic_room",
549 table="users_in_public_rooms",
540550 keyvalues={"user_id": user_id},
541551 retcols=("room_id",),
542552 allow_none=True,
628638 ]
629639 }
630640 """
641
642 if self.hs.config.user_directory_search_all_users:
643 # make s.user_id null to keep the ordering algorithm happy
644 join_clause = """
645 CROSS JOIN (SELECT NULL as user_id) AS s
646 """
647 join_args = ()
648 where_clause = "1=1"
649 else:
650 join_clause = """
651 LEFT JOIN users_in_public_rooms AS p USING (user_id)
652 LEFT JOIN (
653 SELECT other_user_id AS user_id FROM users_who_share_rooms
654 WHERE user_id = ? AND share_private
655 ) AS s USING (user_id)
656 """
657 join_args = (user_id,)
658 where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
659
631660 if isinstance(self.database_engine, PostgresEngine):
632661 full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
633662
640669 SELECT d.user_id, display_name, avatar_url
641670 FROM user_directory_search
642671 INNER JOIN user_directory AS d USING (user_id)
643 LEFT JOIN users_in_pubic_room AS p USING (user_id)
644 LEFT JOIN (
645 SELECT other_user_id AS user_id FROM users_who_share_rooms
646 WHERE user_id = ? AND share_private
647 ) AS s USING (user_id)
672 %s
648673 WHERE
649 (s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
674 %s
650675 AND vector @@ to_tsquery('english', ?)
651676 ORDER BY
652677 (CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
670695 display_name IS NULL,
671696 avatar_url IS NULL
672697 LIMIT ?
673 """
674 args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
698 """ % (join_clause, where_clause)
699 args = join_args + (full_query, exact_query, prefix_query, limit + 1,)
675700 elif isinstance(self.database_engine, Sqlite3Engine):
676701 search_query = _parse_query_sqlite(search_term)
677702
679704 SELECT d.user_id, display_name, avatar_url
680705 FROM user_directory_search
681706 INNER JOIN user_directory AS d USING (user_id)
682 LEFT JOIN users_in_pubic_room AS p USING (user_id)
683 LEFT JOIN (
684 SELECT other_user_id AS user_id FROM users_who_share_rooms
685 WHERE user_id = ? AND share_private
686 ) AS s USING (user_id)
707 %s
687708 WHERE
688 (s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
709 %s
689710 AND value MATCH ?
690711 ORDER BY
691712 rank(matchinfo(user_directory_search)) DESC,
692713 display_name IS NULL,
693714 avatar_url IS NULL
694715 LIMIT ?
695 """
696 args = (user_id, search_query, limit + 1)
716 """ % (join_clause, where_clause)
717 args = join_args + (search_query, limit + 1)
697718 else:
698719 # This should be unreachable.
699720 raise Exception("Unrecognized database engine")
722743
723744 # Pull out the individual words, discarding any non-word characters.
724745 results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
725 return " & ".join("(%s* | %s)" % (result, result,) for result in results)
746 return " & ".join("(%s* OR %s)" % (result, result,) for result in results)
726747
727748
728749 def _parse_query_postgres(search_term):
7979 from_tok = None # For backwards compat.
8080 elif from_tok:
8181 from_tok = StreamToken.from_string(from_tok)
82 except:
82 except Exception:
8383 raise SynapseError(400, "'from' paramater is invalid")
8484
8585 try:
8686 if to_tok:
8787 to_tok = StreamToken.from_string(to_tok)
88 except:
88 except Exception:
8989 raise SynapseError(400, "'to' paramater is invalid")
9090
9191 limit = get_param("limit", None)
9797
9898 try:
9999 return PaginationConfig(from_tok, to_tok, direction, limit)
100 except:
100 except Exception:
101101 logger.exception("Failed to create pagination config")
102102 raise SynapseError(400, "Invalid request.")
103103
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 import string
1415
1516 from synapse.api.errors import SynapseError
1617
1718 from collections import namedtuple
1819
1920
20 Requester = namedtuple("Requester", [
21 class Requester(namedtuple("Requester", [
2122 "user", "access_token_id", "is_guest", "device_id", "app_service",
22 ])
23 """
24 Represents the user making a request
25
26 Attributes:
27 user (UserID): id of the user making the request
28 access_token_id (int|None): *ID* of the access token used for this
29 request, or None if it came via the appservice API or similar
30 is_guest (bool): True if the user making this request is a guest user
31 device_id (str|None): device_id which was set at authentication time
32 app_service (ApplicationService|None): the AS requesting on behalf of the user
33 """
23 ])):
24 """
25 Represents the user making a request
26
27 Attributes:
28 user (UserID): id of the user making the request
29 access_token_id (int|None): *ID* of the access token used for this
30 request, or None if it came via the appservice API or similar
31 is_guest (bool): True if the user making this request is a guest user
32 device_id (str|None): device_id which was set at authentication time
33 app_service (ApplicationService|None): the AS requesting on behalf of the user
34 """
35
36 def serialize(self):
37 """Converts self to a type that can be serialized as JSON, and then
38 deserialized by `deserialize`
39
40 Returns:
41 dict
42 """
43 return {
44 "user_id": self.user.to_string(),
45 "access_token_id": self.access_token_id,
46 "is_guest": self.is_guest,
47 "device_id": self.device_id,
48 "app_server_id": self.app_service.id if self.app_service else None,
49 }
50
51 @staticmethod
52 def deserialize(store, input):
53 """Converts a dict that was produced by `serialize` back into a
54 Requester.
55
56 Args:
57 store (DataStore): Used to convert AS ID to AS object
58 input (dict): A dict produced by `serialize`
59
60 Returns:
61 Requester
62 """
63 appservice = None
64 if input["app_server_id"]:
65 appservice = store.get_app_service_by_id(input["app_server_id"])
66
67 return Requester(
68 user=UserID.from_string(input["user_id"]),
69 access_token_id=input["access_token_id"],
70 is_guest=input["is_guest"],
71 device_id=input["device_id"],
72 app_service=appservice,
73 )
3474
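A minimal round-trip sketch of the new serialize/deserialize helpers (assuming `store` is a DataStore exposing get_app_service_by_id, and using create_requester defined below):

    requester = create_requester("@alice:example.org")
    as_dict = requester.serialize()                   # JSON-safe dict
    restored = Requester.deserialize(store, as_dict)
    assert restored.user.to_string() == "@alice:example.org"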
3575
3676 def create_requester(user_id, access_token_id=None, is_guest=False,
125165 try:
126166 cls.from_string(s)
127167 return True
128 except:
168 except Exception:
129169 return False
130170
131171 __str__ = to_string
132
133 @classmethod
134 def create(cls, localpart, domain,):
135 return cls(localpart=localpart, domain=domain)
136172
137173
138174 class UserID(DomainSpecificString):
158194 class GroupID(DomainSpecificString):
159195 """Structure representing a group ID."""
160196 SIGIL = "+"
197
198 @classmethod
199 def from_string(cls, s):
200 group_id = super(GroupID, cls).from_string(s)
201 if not group_id.localpart:
202 raise SynapseError(
203 400,
204 "Group ID cannot be empty",
205 )
206
207 if contains_invalid_mxid_characters(group_id.localpart):
208 raise SynapseError(
209 400,
210 "Group ID can only contain characters a-z, 0-9, or '=_-./'",
211 )
212
213 return group_id
214
215
216 mxid_localpart_allowed_characters = set("_-./=" + string.ascii_lowercase + string.digits)
217
218
219 def contains_invalid_mxid_characters(localpart):
220 """Check for characters not allowed in an mxid or groupid localpart
221
222 Args:
223 localpart (basestring): the localpart to be checked
224
225 Returns:
226 bool: True if there are any naughty characters
227 """
228 return any(c not in mxid_localpart_allowed_characters for c in localpart)
161229
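Illustrative examples of the new check (it is purely character-set based, so upper case and most punctuation are rejected):

    contains_invalid_mxid_characters("alice.smith-1")  # False: all chars allowed
    contains_invalid_mxid_characters("Alice")          # True: upper case rejected
    contains_invalid_mxid_characters("bob@host")       # True: '@' is not allowed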
162230
163231 class StreamToken(
183251 # i.e. old token from before receipt_key
184252 keys.append("0")
185253 return cls(*keys)
186 except:
254 except Exception:
187255 raise SynapseError(400, "Invalid Token")
188256
189257 def to_string(self):
269337 if string[0] == 't':
270338 parts = string[1:].split('-', 1)
271339 return cls(topological=int(parts[0]), stream=int(parts[1]))
272 except:
340 except Exception:
273341 pass
274342 raise SynapseError(400, "Invalid token %r" % (string,))
275343
278346 try:
279347 if string[0] == 's':
280348 return cls(topological=None, stream=int(string[1:]))
281 except:
349 except Exception:
282350 pass
283351 raise SynapseError(400, "Invalid token %r" % (string,))
284352
5858 f(function): The function to call repeatedly.
5959 msec(float): How long to wait between calls in milliseconds.
6060 """
61 l = task.LoopingCall(f)
62 l.start(msec / 1000.0, now=False)
63 return l
61 call = task.LoopingCall(f)
62 call.start(msec / 1000.0, now=False)
63 return call
6464
6565 def call_later(self, delay, callback, *args, **kwargs):
6666 """Call something later
8181 def cancel_call_later(self, timer, ignore_errs=False):
8282 try:
8383 timer.cancel()
84 except:
84 except Exception:
8585 if not ignore_errs:
8686 raise
8787
9696
9797 try:
9898 ret_deferred.errback(e)
99 except:
99 except Exception:
100100 pass
101101
102102 try:
103103 given_deferred.cancel()
104 except:
104 except Exception:
105105 pass
106106
107107 timer = None
109109 def cancel(res):
110110 try:
111111 self.cancel_call_later(timer)
112 except:
112 except Exception:
113113 pass
114114 return res
115115
118118 def success(res):
119119 try:
120120 ret_deferred.callback(res)
121 except:
121 except Exception:
122122 pass
123123
124124 return res
126126 def err(res):
127127 try:
128128 ret_deferred.errback(res)
129 except:
129 except Exception:
130130 pass
131131
132132 given_deferred.addCallbacks(callback=success, errback=err)
1616 from twisted.internet import defer, reactor
1717
1818 from .logcontext import (
19 PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
19 PreserveLoggingContext, make_deferred_yieldable, preserve_fn
2020 )
2121 from synapse.util import logcontext, unwrapFirstError
2222
7272 try:
7373 # TODO: Handle errors here.
7474 self._observers.pop().callback(r)
75 except:
75 except Exception:
7676 pass
7777 return r
7878
8282 try:
8383 # TODO: Handle errors here.
8484 self._observers.pop().errback(f)
85 except:
85 except Exception:
8686 pass
8787
8888 if consumeErrors:
204204 try:
205205 with PreserveLoggingContext():
206206 yield current_defer
207 except:
207 except Exception:
208208 logger.exception("Unexpected exception in Linearizer")
209209
210210 logger.info("Acquired linearizer lock %r for key %r", self.name,
277277 if entry[0] >= self.max_count:
278278 new_defer = defer.Deferred()
279279 entry[1].append(new_defer)
280
281 logger.info("Waiting to acquire limiter lock for key %r", key)
280282 with PreserveLoggingContext():
281283 yield new_defer
284 logger.info("Acquired limiter lock for key %r", key)
285 else:
286 logger.info("Acquired uncontended limiter lock for key %r", key)
282287
283288 entry[0] += 1
284289
287292 try:
288293 yield
289294 finally:
295 logger.info("Releasing limiter lock for key %r", key)
296
290297 # We've finished executing so check if there are any things
291298 # blocked waiting to execute and start one of them
292299 entry[0] -= 1
293 try:
294 entry[1].pop(0).callback(None)
295 except IndexError:
296 # If nothing else is executing for this key then remove it
297 # from the map
298 if entry[0] == 0:
299 self.key_to_defer.pop(key, None)
300
301 if entry[1]:
302 next_def = entry[1].pop(0)
303
304 with PreserveLoggingContext():
305 next_def.callback(None)
306 elif entry[0] == 0:
307 # We were the last thing for this key: remove it from the
308 # map.
309 del self.key_to_defer[key]
300310
301311 defer.returnValue(_ctx_manager())
302312
340350
341351 # We wait for the latest writer to finish writing. We can safely ignore
342352 # any existing readers... as they're readers.
343 yield curr_writer
353 yield make_deferred_yieldable(curr_writer)
344354
345355 @contextmanager
346356 def _ctx_manager():
369379 curr_readers.clear()
370380 self.key_to_current_writer[key] = new_defer
371381
372 yield preserve_context_over_deferred(defer.gatherResults(to_wait_on))
382 yield make_deferred_yieldable(defer.gatherResults(to_wait_on))
373383
374384 @contextmanager
375385 def _ctx_manager():
7474 self.cache = LruCache(
7575 max_size=max_entries, keylen=keylen, cache_type=cache_type,
7676 size_callback=(lambda d: len(d)) if iterable else None,
77 evicted_callback=self._on_evicted,
7778 )
7879
7980 self.name = name
8182 self.sequence = 0
8283 self.thread = None
8384 self.metrics = register_cache(name, self.cache)
85
86 def _on_evicted(self, evicted_count):
87 self.metrics.inc_evictions(evicted_count)
8488
8589 def check_thread(self):
8690 expected_thread = self.thread
131131 self._update_or_insert(key, value, known_absent)
132132
133133 def _update_or_insert(self, key, value, known_absent):
134 entry = self.cache.setdefault(key, DictionaryEntry(False, set(), {}))
134 # We pop and reinsert as we need to tell the cache the size may have
135 # changed
136
137 entry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
135138 entry.value.update(value)
136139 entry.known_absent.update(known_absent)
140 self.cache[key] = entry
137141
138142 def _insert(self, key, value, known_absent):
139143 self.cache[key] = DictionaryEntry(True, known_absent, value)
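The pop-and-reinsert in _update_or_insert is deliberate; a short sketch of the reasoning, with cache and new_values as illustrative stand-ins:

    # An LruCache constructed with a size_callback only re-measures an entry
    # when the key is assigned, so mutating the stored entry in place would
    # leave the cached size estimate stale.
    entry = cache.pop(key, DictionaryEntry(False, set(), {}))  # old size dropped
    entry.value.update(new_values)                             # mutate freely
    cache[key] = entry                                         # size re-measured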
7878 while self._max_len and len(self) > self._max_len:
7979 _key, value = self._cache.popitem(last=False)
8080 if self.iterable:
81 self._size_estimate -= len(value.value)
81 removed_len = len(value.value)
82 self.metrics.inc_evictions(removed_len)
83 self._size_estimate -= removed_len
84 else:
85 self.metrics.inc_evictions()
8286
8387 def __getitem__(self, key):
8488 try:
4848 Can also set callbacks on objects when getting/setting which are fired
4949 when that key gets invalidated/evicted.
5050 """
51 def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
51 def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None,
52 evicted_callback=None):
53 """
54 Args:
55 max_size (int):
56
57 keylen (int):
58
59 cache_type (type):
60 type of underlying cache to be used. Typically one of dict
61 or TreeCache.
62
63 size_callback (func(V) -> int | None):
64
65 evicted_callback (func(int)|None):
66 if not None, called on eviction with the size of the evicted
67 entry
68 """
5269 cache = cache_type()
5370 self.cache = cache # Used for introspection.
5471 list_root = _Node(None, None, None, None)
6077 def evict():
6178 while cache_len() > max_size:
6279 todelete = list_root.prev_node
63 delete_node(todelete)
80 evicted_len = delete_node(todelete)
6481 cache.pop(todelete.key, None)
82 if evicted_callback:
83 evicted_callback(evicted_len)
6584
6685 def synchronized(f):
6786 @wraps(f)
110129 prev_node.next_node = next_node
111130 next_node.prev_node = prev_node
112131
132 deleted_len = 1
113133 if size_callback:
114 cached_cache_len[0] -= size_callback(node.value)
134 deleted_len = size_callback(node.value)
135 cached_cache_len[0] -= deleted_len
115136
116137 for cb in node.callbacks:
117138 cb()
118139 node.callbacks.clear()
140 return deleted_len
119141
120142 @synchronized
121143 def cache_get(key, default=None, callbacks=[]):
131153 def cache_set(key, value, callbacks=[]):
132154 node = cache.get(key, None)
133155 if node is not None:
134 if value != node.value:
156 # We sometimes store large objects, e.g. dicts, which cause
157 # the inequality check to take a long time. So let's only do
158 # the check if we have some callbacks to call.
159 if node.callbacks and value != node.value:
135160 for cb in node.callbacks:
136161 cb()
137162 node.callbacks.clear()
138163
139 if size_callback:
140 cached_cache_len[0] -= size_callback(node.value)
141 cached_cache_len[0] += size_callback(value)
164 # We don't bother to protect this by value != node.value as
165 # generally size_callback will be cheap compared with equality
166 # checks. (For example, taking the size of two dicts is quicker
167 # than comparing them for equality.)
168 if size_callback:
169 cached_cache_len[0] -= size_callback(node.value)
170 cached_cache_len[0] += size_callback(value)
142171
143172 node.callbacks.update(callbacks)
144173
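A minimal sketch of wiring the new evicted_callback parameter, mirroring how Cache._on_evicted above feeds eviction counts into its metrics (the metrics object here is an illustrative stand-in):

    cache = LruCache(
        max_size=1000,
        size_callback=len,  # iterable values are sized by their length
        evicted_callback=lambda evicted_len: metrics.inc_evictions(evicted_len),
    )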
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 import logging
16
1517 from twisted.internet import defer
1618
17 from synapse.util.logcontext import (
18 PreserveLoggingContext, preserve_context_over_fn
19 )
20
2119 from synapse.util import unwrapFirstError
22
23 import logging
24
20 from synapse.util.logcontext import PreserveLoggingContext
2521
2622 logger = logging.getLogger(__name__)
2723
2824
2925 def user_left_room(distributor, user, room_id):
30 return preserve_context_over_fn(
31 distributor.fire,
32 "user_left_room", user=user, room_id=room_id
33 )
26 with PreserveLoggingContext():
27 distributor.fire("user_left_room", user=user, room_id=room_id)
3428
3529
3630 def user_joined_room(distributor, user, room_id):
37 return preserve_context_over_fn(
38 distributor.fire,
39 "user_joined_room", user=user, room_id=room_id
40 )
31 with PreserveLoggingContext():
32 distributor.fire("user_joined_room", user=user, room_id=room_id)
4133
4234
4335 class Distributor(object):
0 # -*- coding: utf-8 -*-
1 # Copyright 2018 New Vector Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from twisted.internet import threads, reactor
16
17 from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
18
19 import Queue
20
21
22 class BackgroundFileConsumer(object):
23 """A consumer that writes to a file like object. Supports both push
24 and pull producers
25
26 Args:
27 file_obj (file): The file like object to write to. Closed when
28 finished.
29 """
30
31 # For PushProducers pause if we have this many unwritten slices
32 _PAUSE_ON_QUEUE_SIZE = 5
33 # And resume once the size of the queue is less than this
34 _RESUME_ON_QUEUE_SIZE = 2
35
36 def __init__(self, file_obj):
37 self._file_obj = file_obj
38
39 # Producer we're registered with
40 self._producer = None
41
42 # True if PushProducer, false if PullProducer
43 self.streaming = False
44
45 # For PushProducers, indicates whether we've paused the producer and
46 # need to call resumeProducing before we get more data.
47 self._paused_producer = False
48
49 # Queue of slices of bytes to be written. When producer calls
50 # unregister a final None is sent.
51 self._bytes_queue = Queue.Queue()
52
53 # Deferred that is resolved when finished writing
54 self._finished_deferred = None
55
56 # If the _writer thread throws an exception it gets stored here.
57 self._write_exception = None
58
59 def registerProducer(self, producer, streaming):
60 """Part of IConsumer interface
61
62 Args:
63 producer (IProducer)
64 streaming (bool): True if push based producer, False if pull
65 based.
66 """
67 if self._producer:
68 raise Exception("registerProducer called twice")
69
70 self._producer = producer
71 self.streaming = streaming
72 self._finished_deferred = preserve_fn(threads.deferToThread)(self._writer)
73 if not streaming:
74 self._producer.resumeProducing()
75
76 def unregisterProducer(self):
77 """Part of IProducer interface
78 """
79 self._producer = None
80 if not self._finished_deferred.called:
81 self._bytes_queue.put_nowait(None)
82
83 def write(self, bytes):
84 """Part of IProducer interface
85 """
86 if self._write_exception:
87 raise self._write_exception
88
89 if self._finished_deferred.called:
90 raise Exception("consumer has closed")
91
92 self._bytes_queue.put_nowait(bytes)
93
94 # If this is a PushProducer and the queue is getting behind
95 # then we pause the producer.
96 if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE:
97 self._paused_producer = True
98 self._producer.pauseProducing()
99
100 def _writer(self):
101 """This is run in a background thread to write to the file.
102 """
103 try:
104 while self._producer or not self._bytes_queue.empty():
105 # If we've paused the producer check if we should resume the
106 # producer.
107 if self._producer and self._paused_producer:
108 if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
109 reactor.callFromThread(self._resume_paused_producer)
110
111 bytes = self._bytes_queue.get()
112
113 # If we get a None (or empty list) then that's a signal used
114 # to indicate we should check if we should stop.
115 if bytes:
116 self._file_obj.write(bytes)
117
118                 # If it's a pull producer then we need to explicitly ask for
119 # more stuff.
120 if not self.streaming and self._producer:
121 reactor.callFromThread(self._producer.resumeProducing)
122 except Exception as e:
123 self._write_exception = e
124 raise
125 finally:
126 self._file_obj.close()
127
128 def wait(self):
129 """Returns a deferred that resolves when finished writing to file
130 """
131 return make_deferred_yieldable(self._finished_deferred)
132
133 def _resume_paused_producer(self):
134 """Gets called if we should resume producing after being paused
135 """
136 if self._paused_producer and self._producer:
137 self._paused_producer = False
138 self._producer.resumeProducing()
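A usage sketch for BackgroundFileConsumer, assuming Twisted's FileSender as the pull producer; the paths are illustrative, and the consumer closes the destination file itself once the writer thread finishes:

    from twisted.internet import defer
    from twisted.protocols.basic import FileSender

    @defer.inlineCallbacks
    def copy_in_background(src_path, dest_path):
        consumer = BackgroundFileConsumer(open(dest_path, "wb"))
        with open(src_path, "rb") as src:
            # beginFileTransfer registers the sender with the consumer and
            # unregisters on completion, telling the writer thread to finish.
            yield FileSender().beginFileTransfer(src, consumer)
        yield consumer.wait()  # resolves once everything has been written out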
4141
4242 def get_thread_resource_usage():
4343 return resource.getrusage(RUSAGE_THREAD)
44 except:
44 except Exception:
4545 # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we
4646 # won't track resource usage by returning None.
4747 def get_thread_resource_usage():
5151 class LoggingContext(object):
5252 """Additional context for log formatting. Contexts are scoped within a
5353 "with" block.
54
5455 Args:
5556 name (str): Name for the context for debugging.
5657 """
5758
5859 __slots__ = [
59 "previous_context", "name", "usage_start", "usage_end", "main_thread",
60 "__dict__", "tag", "alive",
60 "previous_context", "name", "ru_stime", "ru_utime",
61 "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
62 "usage_start", "usage_end",
63 "main_thread", "alive",
64 "request", "tag",
6165 ]
6266
6367 thread_local = threading.local()
8084 pass
8185
8286 def add_database_transaction(self, duration_ms):
87 pass
88
89 def add_database_scheduled(self, sched_ms):
8390 pass
8491
8592 def __nonzero__(self):
93100 self.ru_stime = 0.
94101 self.ru_utime = 0.
95102 self.db_txn_count = 0
96 self.db_txn_duration = 0.
103
104 # ms spent waiting for db txns, excluding scheduling time
105 self.db_txn_duration_ms = 0
106
107 # ms spent waiting for db txns to be scheduled
108 self.db_sched_duration_ms = 0
109
97110 self.usage_start = None
111 self.usage_end = None
98112 self.main_thread = threading.current_thread()
113 self.request = None
99114 self.tag = ""
100115 self.alive = True
101116
104119
105120 @classmethod
106121 def current_context(cls):
107 """Get the current logging context from thread local storage"""
122 """Get the current logging context from thread local storage
123
124 Returns:
125 LoggingContext: the current logging context
126 """
108127 return getattr(cls.thread_local, "current_context", cls.sentinel)
109128
110129 @classmethod
154173 self.alive = False
155174
156175 def copy_to(self, record):
157 """Copy fields from this context to the record"""
158 for key, value in self.__dict__.items():
159 setattr(record, key, value)
160
161 record.ru_utime, record.ru_stime = self.get_resource_usage()
176 """Copy logging fields from this context to a log record or
177 another LoggingContext
178 """
179
180 # 'request' is the only field we currently use in the logger, so that's
181 # all we need to copy
182 record.request = self.request
162183
163184 def start(self):
164185 if threading.current_thread() is not self.main_thread:
193214
194215 def add_database_transaction(self, duration_ms):
195216 self.db_txn_count += 1
196 self.db_txn_duration += duration_ms / 1000.
217 self.db_txn_duration_ms += duration_ms
218
219 def add_database_scheduled(self, sched_ms):
220 """Record a use of the database pool
221
222 Args:
223 sched_ms (int): number of milliseconds it took us to get a
224 connection
225 """
226 self.db_sched_duration_ms += sched_ms
197227
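A sketch of how the database layer is expected to feed the two new counters; the clock and timestamp arguments are illustrative:

    def record_db_timings(clock, txn_start_ms, sched_start_ms):
        context = LoggingContext.current_context()
        # milliseconds spent waiting for a connection from the pool
        context.add_database_scheduled(txn_start_ms - sched_start_ms)
        # milliseconds spent in the transaction itself, excluding scheduling
        context.add_database_transaction(clock.time_msec() - txn_start_ms)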
198228
199229 class LoggingContextFilter(logging.Filter):
260290 )
261291
262292
263 class _PreservingContextDeferred(defer.Deferred):
264 """A deferred that ensures that all callbacks and errbacks are called with
265 the given logging context.
266 """
267 def __init__(self, context):
268 self._log_context = context
269 defer.Deferred.__init__(self)
270
271 def addCallbacks(self, callback, errback=None,
272 callbackArgs=None, callbackKeywords=None,
273 errbackArgs=None, errbackKeywords=None):
274 callback = self._wrap_callback(callback)
275 errback = self._wrap_callback(errback)
276 return defer.Deferred.addCallbacks(
277 self, callback,
278 errback=errback,
279 callbackArgs=callbackArgs,
280 callbackKeywords=callbackKeywords,
281 errbackArgs=errbackArgs,
282 errbackKeywords=errbackKeywords,
283 )
284
285 def _wrap_callback(self, f):
286 def g(res, *args, **kwargs):
287 with PreserveLoggingContext(self._log_context):
288 res = f(res, *args, **kwargs)
289 return res
290 return g
291
292
293 def preserve_context_over_fn(fn, *args, **kwargs):
294 """Takes a function and invokes it with the given arguments, but removes
295 and restores the current logging context while doing so.
296
297 If the result is a deferred, call preserve_context_over_deferred before
298 returning it.
299 """
300 with PreserveLoggingContext():
301 res = fn(*args, **kwargs)
302
303 if isinstance(res, defer.Deferred):
304 return preserve_context_over_deferred(res)
305 else:
306 return res
307
308
309 def preserve_context_over_deferred(deferred, context=None):
310 """Given a deferred wrap it such that any callbacks added later to it will
311 be invoked with the current context.
312
313     Deprecated: this almost certainly doesn't do what you want, i.e. make
314 the deferred follow the synapse logcontext rules: try
315 ``make_deferred_yieldable`` instead.
316 """
317 if context is None:
318 context = LoggingContext.current_context()
319 d = _PreservingContextDeferred(context)
320 deferred.chainDeferred(d)
321 return d
322
323
324293 def preserve_fn(f):
325 """Wraps a function, to ensure that the current context is restored after
294 """Function decorator which wraps the function with run_in_background"""
295 def g(*args, **kwargs):
296 return run_in_background(f, *args, **kwargs)
297 return g
298
299
300 def run_in_background(f, *args, **kwargs):
301 """Calls a function, ensuring that the current context is restored after
326302 return from the function, and that the sentinel context is set once the
327303     deferred returned by the function completes.
328304
329305 Useful for wrapping functions that return a deferred which you don't yield
330306 on.
331307 """
332 def reset_context(result):
333 LoggingContext.set_current_context(LoggingContext.sentinel)
334 return result
335
336 def g(*args, **kwargs):
337 current = LoggingContext.current_context()
338 res = f(*args, **kwargs)
339 if isinstance(res, defer.Deferred) and not res.called:
340 # The function will have reset the context before returning, so
341 # we need to restore it now.
342 LoggingContext.set_current_context(current)
343
344 # The original context will be restored when the deferred
345 # completes, but there is nothing waiting for it, so it will
346 # get leaked into the reactor or some other function which
347 # wasn't expecting it. We therefore need to reset the context
348 # here.
349 #
350 # (If this feels asymmetric, consider it this way: we are
351 # effectively forking a new thread of execution. We are
352 # probably currently within a ``with LoggingContext()`` block,
353 # which is supposed to have a single entry and exit point. But
354 # by spawning off another deferred, we are effectively
355 # adding a new exit point.)
356 res.addBoth(reset_context)
357 return res
358 return g
359
360
361 @defer.inlineCallbacks
308 current = LoggingContext.current_context()
309 res = f(*args, **kwargs)
310 if isinstance(res, defer.Deferred) and not res.called:
311 # The function will have reset the context before returning, so
312 # we need to restore it now.
313 LoggingContext.set_current_context(current)
314
315 # The original context will be restored when the deferred
316 # completes, but there is nothing waiting for it, so it will
317 # get leaked into the reactor or some other function which
318 # wasn't expecting it. We therefore need to reset the context
319 # here.
320 #
321 # (If this feels asymmetric, consider it this way: we are
322 # effectively forking a new thread of execution. We are
323 # probably currently within a ``with LoggingContext()`` block,
324 # which is supposed to have a single entry and exit point. But
325 # by spawning off another deferred, we are effectively
326 # adding a new exit point.)
327 res.addBoth(_set_context_cb, LoggingContext.sentinel)
328 return res
329
330
362331 def make_deferred_yieldable(deferred):
363332 """Given a deferred, make it follow the Synapse logcontext rules:
364333
372341
373342 (This is more-or-less the opposite operation to preserve_fn.)
374343 """
375 with PreserveLoggingContext():
376 r = yield deferred
377 defer.returnValue(r)
344 if isinstance(deferred, defer.Deferred) and not deferred.called:
345 prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
346 deferred.addBoth(_set_context_cb, prev_context)
347 return deferred
348
349
350 def _set_context_cb(result, context):
351 """A callback function which just sets the logging context"""
352 LoggingContext.set_current_context(context)
353 return result
378354
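A sketch contrasting the two helpers inside an inlineCallbacks function; background_task and get_remote_thing are illustrative functions that return deferreds:

    @defer.inlineCallbacks
    def handler():
        # fire-and-forget: our context is restored straight away and reset to
        # the sentinel once the background deferred completes, so it can't leak.
        run_in_background(background_task)

        # waiting on a deferred from code that doesn't follow the logcontext
        # rules: the current context is parked while waiting and restored after.
        result = yield make_deferred_yieldable(get_remote_thing())
        defer.returnValue(result)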
379355
380356 # modules to ignore in `logcontext_tracer`
2626
2727 metrics = synapse.metrics.get_metrics_for(__name__)
2828
29 block_timer = metrics.register_distribution(
30 "block_timer",
31 labels=["block_name"]
29 # total number of times we have hit this block
30 block_counter = metrics.register_counter(
31 "block_count",
32 labels=["block_name"],
33 alternative_names=(
34 # the following are all deprecated aliases for the same metric
35 metrics.name_prefix + x for x in (
36 "_block_timer:count",
37 "_block_ru_utime:count",
38 "_block_ru_stime:count",
39 "_block_db_txn_count:count",
40 "_block_db_txn_duration:count",
41 )
42 )
3243 )
3344
34 block_ru_utime = metrics.register_distribution(
35 "block_ru_utime", labels=["block_name"]
45 block_timer = metrics.register_counter(
46 "block_time_seconds",
47 labels=["block_name"],
48 alternative_names=(
49 metrics.name_prefix + "_block_timer:total",
50 ),
3651 )
3752
38 block_ru_stime = metrics.register_distribution(
39 "block_ru_stime", labels=["block_name"]
53 block_ru_utime = metrics.register_counter(
54 "block_ru_utime_seconds", labels=["block_name"],
55 alternative_names=(
56 metrics.name_prefix + "_block_ru_utime:total",
57 ),
4058 )
4159
42 block_db_txn_count = metrics.register_distribution(
43 "block_db_txn_count", labels=["block_name"]
60 block_ru_stime = metrics.register_counter(
61 "block_ru_stime_seconds", labels=["block_name"],
62 alternative_names=(
63 metrics.name_prefix + "_block_ru_stime:total",
64 ),
4465 )
4566
46 block_db_txn_duration = metrics.register_distribution(
47 "block_db_txn_duration", labels=["block_name"]
67 block_db_txn_count = metrics.register_counter(
68 "block_db_txn_count", labels=["block_name"],
69 alternative_names=(
70 metrics.name_prefix + "_block_db_txn_count:total",
71 ),
72 )
73
74 # seconds spent waiting for db txns, excluding scheduling time, in this block
75 block_db_txn_duration = metrics.register_counter(
76 "block_db_txn_duration_seconds", labels=["block_name"],
77 alternative_names=(
78 metrics.name_prefix + "_block_db_txn_duration:total",
79 ),
80 )
81
82 # seconds spent waiting for a db connection, in this block
83 block_db_sched_duration = metrics.register_counter(
84 "block_db_sched_duration_seconds", labels=["block_name"],
4885 )
4986
5087
63100 class Measure(object):
64101 __slots__ = [
65102 "clock", "name", "start_context", "start", "new_context", "ru_utime",
66 "ru_stime", "db_txn_count", "db_txn_duration", "created_context"
103 "ru_stime",
104 "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
105 "created_context",
67106 ]
68107
69108 def __init__(self, clock, name):
83122
84123 self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
85124 self.db_txn_count = self.start_context.db_txn_count
86 self.db_txn_duration = self.start_context.db_txn_duration
125 self.db_txn_duration_ms = self.start_context.db_txn_duration_ms
126 self.db_sched_duration_ms = self.start_context.db_sched_duration_ms
87127
88128 def __exit__(self, exc_type, exc_val, exc_tb):
89129 if isinstance(exc_type, Exception) or not self.start_context:
90130 return
91131
92132 duration = self.clock.time_msec() - self.start
133
134 block_counter.inc(self.name)
93135 block_timer.inc_by(duration, self.name)
94136
95137 context = LoggingContext.current_context()
113155 context.db_txn_count - self.db_txn_count, self.name
114156 )
115157 block_db_txn_duration.inc_by(
116 context.db_txn_duration - self.db_txn_duration, self.name
158 (context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000.,
159 self.name
160 )
161 block_db_sched_duration.inc_by(
162 (context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000.,
163 self.name
117164 )
118165
119166 if self.created_context:
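For context, these counters are driven by the Measure context manager; a minimal sketch, where the block name is illustrative and becomes the block_name label on each metric:

    # inside an inlineCallbacks method on an object with a clock
    with Measure(self.clock, "process_transaction"):
        yield self._process_transaction(transaction)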
2525
2626 class NotRetryingDestination(Exception):
2727 def __init__(self, retry_last_ts, retry_interval, destination):
28 """Raised by the limiter (and federation client) to indicate that we are
29         deliberately not attempting to contact a given server.
30
31 Args:
32 retry_last_ts (int): the unix ts in milliseconds of our last attempt
33 to contact the server. 0 indicates that the last attempt was
34 successful or that we've never actually attempted to connect.
35 retry_interval (int): the time in milliseconds to wait until the next
36 attempt.
37 destination (str): the domain in question
38 """
39
2840 msg = "Not retrying server %s." % (destination,)
2941 super(NotRetryingDestination, self).__init__(msg)
3042
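A sketch of how callers typically encounter this exception, assuming the module's get_retry_limiter helper (do_request is illustrative):

    @defer.inlineCallbacks
    def send_to(destination, clock, store):
        try:
            limiter = yield get_retry_limiter(destination, clock, store)
        except NotRetryingDestination:
            logger.debug("Not retrying %s: still backing off", destination)
            return
        with limiter:
            # failures inside the block feed back into the retry timings
            yield do_request(destination)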
188200 yield self.store.set_destination_retry_timings(
189201 self.destination, retry_last_ts, self.retry_interval
190202 )
191 except:
203 except Exception:
192204 logger.exception(
193205 "Failed to store set_destination_retry_timings",
194206 )
0 # -*- coding: utf-8 -*-
1 # Copyright 2018 New Vector Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import logging
16 import re
17
18 logger = logging.getLogger(__name__)
19
20
21 def check_3pid_allowed(hs, medium, address):
22 """Checks whether a given format of 3PID is allowed to be used on this HS
23
24 Args:
25 hs (synapse.server.HomeServer): server
26 medium (str): 3pid medium - e.g. email, msisdn
27 address (str): address within that medium (e.g. "wotan@matrix.org")
28 msisdns need to first have been canonicalised
29 Returns:
30 bool: whether the 3PID medium/address is allowed to be added to this HS
31 """
32
33 if hs.config.allowed_local_3pids:
34 for constraint in hs.config.allowed_local_3pids:
35 logger.debug(
36 "Checking 3PID %s (%s) against %s (%s)",
37 address, medium, constraint['pattern'], constraint['medium'],
38 )
39 if (
40 medium == constraint['medium'] and
41 re.match(constraint['pattern'], address)
42 ):
43 return True
44 else:
45 return True
46
47 return False
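A sketch of the homeserver configuration this check enforces, with the expected results spelled out; the YAML key mirrors hs.config.allowed_local_3pids as read above, and the addresses are illustrative:

    # homeserver.yaml:
    #   allowed_local_3pids:
    #     - medium: email
    #       pattern: '.*@matrix\.org'
    #     - medium: msisdn
    #       pattern: '\+44'
    #
    # With that config:
    #   check_3pid_allowed(hs, "email", "wotan@matrix.org")   # True
    #   check_3pid_allowed(hs, "email", "wotan@example.com")  # False
    # With no allowed_local_3pids configured, every 3PID is allowed.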
9090 return ret
9191
9292 def __len__(self):
93 l = 0
94 for entry in self.entries:
95 l += len(entry.queue)
96 return l
93 return sum(len(entry.queue) for entry in self.entries)
1616
1717 from synapse.api.constants import Membership, EventTypes
1818
19 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
19 from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
2020
2121 import logging
2222
5757 always_include_ids (set(event_id)): set of event ids to specifically
5858 include (unless sender is ignored)
5959 """
60 forgotten = yield preserve_context_over_deferred(defer.gatherResults([
60 forgotten = yield make_deferred_yieldable(defer.gatherResults([
6161 defer.maybeDeferred(
6262 preserve_fn(store.who_forgot_in_room),
6363 room_id,
3535 id="unique_identifier",
3636 url="some_url",
3737 token="some_token",
38 hostname="matrix.org", # only used by get_groups_for_user
3839 namespaces={
3940 ApplicationService.NS_USERS: [],
4041 ApplicationService.NS_ROOMS: [],
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14
1415 import os.path
16 import re
1517 import shutil
1618 import tempfile
19
1720 from synapse.config.homeserver import HomeServerConfig
1821 from tests import unittest
1922
2225
2326 def setUp(self):
2427 self.dir = tempfile.mkdtemp()
25 print self.dir
2628 self.file = os.path.join(self.dir, "homeserver.yaml")
2729
2830 def tearDown(self):
4749 ]),
4850 set(os.listdir(self.dir))
4951 )
52
53 self.assert_log_filename_is(
54 os.path.join(self.dir, "lemurs.win.log.config"),
55 os.path.join(os.getcwd(), "homeserver.log"),
56 )
57
58 def assert_log_filename_is(self, log_config_file, expected):
59 with open(log_config_file) as f:
60 config = f.read()
61 # find the 'filename' line
62         matches = re.findall(r"^\s*filename:\s*(.*)$", config, re.M)
63 self.assertEqual(1, len(matches))
64 self.assertEqual(matches[0], expected)
6767
6868 def check_context(self, _, expected):
6969 self.assertEquals(
70 getattr(LoggingContext.current_context(), "test_key", None),
70 getattr(LoggingContext.current_context(), "request", None),
7171 expected
7272 )
7373
8181 lookup_2_deferred = defer.Deferred()
8282
8383 with LoggingContext("one") as context_one:
84 context_one.test_key = "one"
84 context_one.request = "one"
8585
8686 wait_1_deferred = kr.wait_for_previous_lookups(
8787 ["server1"],
9595 wait_1_deferred.addBoth(self.check_context, "one")
9696
9797 with LoggingContext("two") as context_two:
98 context_two.test_key = "two"
98 context_two.request = "two"
9999
100100 # set off another wait. It should block because the first lookup
101101 # hasn't yet completed.
136136 @defer.inlineCallbacks
137137 def get_perspectives(**kwargs):
138138 self.assertEquals(
139 LoggingContext.current_context().test_key, "11",
139 LoggingContext.current_context().request, "11",
140140 )
141141 with logcontext.PreserveLoggingContext():
142142 yield persp_deferred
144144 self.http_client.post_json.side_effect = get_perspectives
145145
146146 with LoggingContext("11") as context_11:
147 context_11.test_key = "11"
147 context_11.request = "11"
148148
149149 # start off a first set of lookups
150150 res_deferreds = kr.verify_json_objects_for_server(
166166
167167 # wait a tick for it to send the request to the perspectives server
168168 # (it first tries the datastore)
169 yield async.sleep(0.005)
169 yield async.sleep(1) # XXX find out why this takes so long!
170170 self.http_client.post_json.assert_called_once()
171171
172172 self.assertIs(LoggingContext.current_context(), context_11)
173173
174174 context_12 = LoggingContext("12")
175 context_12.test_key = "12"
175 context_12.request = "12"
176176 with logcontext.PreserveLoggingContext(context_12):
177177 # a second request for a server with outstanding requests
178178 # should block rather than start a second call
182182 res_deferreds_2 = kr.verify_json_objects_for_server(
183183 [("server10", json1)],
184184 )
185 yield async.sleep(0.005)
185             yield async.sleep(1)
186186 self.http_client.post_json.assert_not_called()
187187 res_deferreds_2[0].addBoth(self.check_context, None)
188188
210210 sentinel_context = LoggingContext.current_context()
211211
212212 with LoggingContext("one") as context_one:
213 context_one.test_key = "one"
213 context_one.request = "one"
214214
215215 defer = kr.verify_json_for_server("server9", {})
216216 try:
5252 type="m.room.message",
5353 room_id="!foo:bar"
5454 )
55 self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
55 self.mock_store.get_new_events_for_appservice.side_effect = [
56 (0, [event]),
57 (0, [])
58 ]
5659 self.mock_as_api.push = Mock()
5760 yield self.handler.notify_interested_services(0)
5861 self.mock_scheduler.submit_event_for_as.assert_called_once_with(
7477 )
7578 self.mock_as_api.push = Mock()
7679 self.mock_as_api.query_user = Mock()
77 self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
80 self.mock_store.get_new_events_for_appservice.side_effect = [
81 (0, [event]),
82 (0, [])
83 ]
7884 yield self.handler.notify_interested_services(0)
7985 self.mock_as_api.query_user.assert_called_once_with(
8086 services[0], user_id
97103 )
98104 self.mock_as_api.push = Mock()
99105 self.mock_as_api.query_user = Mock()
100 self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
106 self.mock_store.get_new_events_for_appservice.side_effect = [
107 (0, [event]),
108 (0, [])
109 ]
101110 yield self.handler.notify_interested_services(0)
102111 self.assertFalse(
103112 self.mock_as_api.query_user.called,
3434
3535 @defer.inlineCallbacks
3636 def setUp(self):
37 self.mock_federation = Mock(spec=[
38 "make_query",
39 "register_edu_handler",
40 ])
37 self.mock_federation = Mock()
38 self.mock_registry = Mock()
4139
4240 self.query_handlers = {}
4341
4442 def register_query_handler(query_type, handler):
4543 self.query_handlers[query_type] = handler
46 self.mock_federation.register_query_handler = register_query_handler
44 self.mock_registry.register_query_handler = register_query_handler
4745
4846 hs = yield setup_test_homeserver(
4947 http_client=None,
5048 resource_for_federation=Mock(),
51 replication_layer=self.mock_federation,
49 federation_client=self.mock_federation,
50 federation_registry=self.mock_registry,
5251 )
5352 hs.handlers = DirectoryHandlers(hs)
5453
3333 def setUp(self):
3434 self.hs = yield utils.setup_test_homeserver(
3535 handlers=None,
36 replication_layer=mock.Mock(),
36 federation_client=mock.Mock(),
3737 )
3838 self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
3939
142142 except errors.SynapseError:
143143 pass
144144
145 @unittest.DEBUG
146145 @defer.inlineCallbacks
147146 def test_claim_one_time_key(self):
148147 local_user = "@boris:" + self.hs.hostname
3636
3737 @defer.inlineCallbacks
3838 def setUp(self):
39 self.mock_federation = Mock(spec=[
40 "make_query",
41 "register_edu_handler",
42 ])
39 self.mock_federation = Mock()
40 self.mock_registry = Mock()
4341
4442 self.query_handlers = {}
4543
4644 def register_query_handler(query_type, handler):
4745 self.query_handlers[query_type] = handler
4846
49 self.mock_federation.register_query_handler = register_query_handler
47 self.mock_registry.register_query_handler = register_query_handler
5048
5149 hs = yield setup_test_homeserver(
5250 http_client=None,
5351 handlers=None,
5452 resource_for_federation=Mock(),
55 replication_layer=self.mock_federation,
53 federation_client=self.mock_federation,
54 federation_server=Mock(),
55 federation_registry=self.mock_registry,
5656 ratelimiter=NonCallableMock(spec_set=[
5757 "send_message",
5858 ])
5757
5858 self.mock_federation_resource = MockHttpResource()
5959
60 mock_notifier = Mock(spec=["on_new_event"])
60 mock_notifier = Mock()
6161 self.on_new_event = mock_notifier.on_new_event
6262
6363 self.auth = Mock(spec=[])
7575 "set_received_txn_response",
7676 "get_destination_retry_timings",
7777 "get_devices_by_remote",
78 # Bits that user_directory needs
79 "get_user_directory_stream_pos",
80 "get_current_state_deltas",
7881 ]),
7982 state_handler=self.state_handler,
80 handlers=None,
83 handlers=Mock(),
8184 notifier=mock_notifier,
8285 resource_for_client=Mock(),
8386 resource_for_federation=self.mock_federation_resource,
120123 def get_current_user_in_room(room_id):
121124 return set(str(u) for u in self.room_members)
122125 self.state_handler.get_current_user_in_room = get_current_user_in_room
126
127 self.datastore.get_user_directory_stream_pos.return_value = (
128 # we deliberately return a non-None stream pos to avoid doing an initial_spam
129 defer.succeed(1)
130 )
131
132 self.datastore.get_current_state_deltas.return_value = (
133 None
134 )
123135
124136 self.auth.check_joined_room = check_joined_room
125137
140140 'cache:hits{name="cache_name"} 0',
141141 'cache:total{name="cache_name"} 0',
142142 'cache:size{name="cache_name"} 0',
143 'cache:evicted_size{name="cache_name"} 0',
143144 ])
144145
145146 metric.inc_misses()
149150 'cache:hits{name="cache_name"} 0',
150151 'cache:total{name="cache_name"} 1',
151152 'cache:size{name="cache_name"} 1',
153 'cache:evicted_size{name="cache_name"} 0',
152154 ])
153155
154156 metric.inc_hits()
157159 'cache:hits{name="cache_name"} 1',
158160 'cache:total{name="cache_name"} 2',
159161 'cache:size{name="cache_name"} 1',
162 'cache:evicted_size{name="cache_name"} 0',
160163 ])
164
165 metric.inc_evictions(2)
166
167 self.assertEquals(metric.render(), [
168 'cache:hits{name="cache_name"} 1',
169 'cache:total{name="cache_name"} 2',
170 'cache:size{name="cache_name"} 1',
171 'cache:evicted_size{name="cache_name"} 2',
172 ])
1414 from twisted.internet import defer, reactor
1515 from tests import unittest
1616
17 import tempfile
18
1719 from mock import Mock, NonCallableMock
1820 from tests.utils import setup_test_homeserver
1921 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
2830 self.hs = yield setup_test_homeserver(
2931 "blue",
3032 http_client=None,
31 replication_layer=Mock(),
33 federation_client=Mock(),
3234 ratelimiter=NonCallableMock(spec_set=[
3335 "send_message",
3436 ]),
4042 self.event_id = 0
4143
4244 server_factory = ReplicationStreamProtocolFactory(self.hs)
43 listener = reactor.listenUNIX("\0xxx", server_factory)
45 # XXX: mktemp is unsafe and should never be used. but we're just a test.
46 path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket")
47 listener = reactor.listenUNIX(path, server_factory)
4448 self.addCleanup(listener.stopListening)
4549 self.streamer = server_factory.streamer
4650
4852 client_factory = ReplicationClientFactory(
4953 self.hs, "client_name", self.replication_handler
5054 )
51 client_connector = reactor.connectUNIX("\0xxx", client_factory)
55 client_connector = reactor.connectUNIX(path, client_factory)
5256 self.addCleanup(client_factory.stopTrying)
5357 self.addCleanup(client_connector.disconnect)
5458
225225 context = EventContext()
226226 context.current_state_ids = state_ids
227227 context.prev_state_ids = state_ids
228 elif not backfill:
228 else:
229229 state_handler = self.hs.get_state_handler()
230230 context = yield state_handler.compute_event_context(event)
231 else:
232 context = EventContext()
233
234 context.push_actions = push_actions
231
232 yield self.master_store.add_push_actions_to_staging(
233 event.event_id, {
234 user_id: actions
235 for user_id, actions in push_actions
236 },
237 )
235238
236239 ordering = None
237240 if backfill:
113113
114114 hs = yield setup_test_homeserver(
115115 http_client=None,
116 replication_layer=Mock(),
116 federation_client=Mock(),
117117 ratelimiter=NonCallableMock(spec_set=[
118118 "send_message",
119119 ]),
4444 http_client=None,
4545 resource_for_client=self.mock_resource,
4646 federation=Mock(),
47 replication_layer=Mock(),
47 federation_client=Mock(),
4848 profile_handler=self.mock_handler
4949 )
5050
4545 hs = yield setup_test_homeserver(
4646 "red",
4747 http_client=None,
48 replication_layer=Mock(),
48 federation_client=Mock(),
4949 ratelimiter=NonCallableMock(spec_set=["send_message"]),
5050 )
5151 self.ratelimiter = hs.get_ratelimiter()
408408 hs = yield setup_test_homeserver(
409409 "red",
410410 http_client=None,
411 replication_layer=Mock(),
411 federation_client=Mock(),
412412 ratelimiter=NonCallableMock(spec_set=["send_message"]),
413413 )
414414 self.ratelimiter = hs.get_ratelimiter()
492492 hs = yield setup_test_homeserver(
493493 "red",
494494 http_client=None,
495 replication_layer=Mock(),
495 federation_client=Mock(),
496496 ratelimiter=NonCallableMock(spec_set=["send_message"]),
497497 )
498498 self.ratelimiter = hs.get_ratelimiter()
514514
515515 synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
516516
517 def tearDown(self):
518 pass
519
520517 @defer.inlineCallbacks
521518 def test_post_room_no_keys(self):
522519 # POST with no config keys, expect new room id
584581 hs = yield setup_test_homeserver(
585582 "red",
586583 http_client=None,
587 replication_layer=Mock(),
584 federation_client=Mock(),
588585 ratelimiter=NonCallableMock(spec_set=["send_message"]),
589586 )
590587 self.ratelimiter = hs.get_ratelimiter()
699696 hs = yield setup_test_homeserver(
700697 "red",
701698 http_client=None,
702 replication_layer=Mock(),
699 federation_client=Mock(),
703700 ratelimiter=NonCallableMock(spec_set=["send_message"]),
704701 )
705702 self.ratelimiter = hs.get_ratelimiter()
831828 hs = yield setup_test_homeserver(
832829 "red",
833830 http_client=None,
834 replication_layer=Mock(),
831 federation_client=Mock(),
835832 ratelimiter=NonCallableMock(spec_set=["send_message"]),
836833 )
837834 self.ratelimiter = hs.get_ratelimiter()
931928 hs = yield setup_test_homeserver(
932929 "red",
933930 http_client=None,
934 replication_layer=Mock(),
931 federation_client=Mock(),
935932 ratelimiter=NonCallableMock(spec_set=[
936933 "send_message",
937934 ]),
10051002 hs = yield setup_test_homeserver(
10061003 "red",
10071004 http_client=None,
1008 replication_layer=Mock(),
1005 federation_client=Mock(),
10091006 ratelimiter=NonCallableMock(spec_set=["send_message"]),
10101007 )
10111008 self.ratelimiter = hs.get_ratelimiter()
4646 "red",
4747 clock=self.clock,
4848 http_client=None,
49 replication_layer=Mock(),
49 federation_client=Mock(),
5050 ratelimiter=NonCallableMock(spec_set=[
5151 "send_message",
5252 ]),
9494 else:
9595 if remotedomains is not None:
9696 remotedomains.add(member.domain)
97 hs.get_handlers().room_member_handler.fetch_room_distributions_into = (
97 hs.get_room_member_handler().fetch_room_distributions_into = (
9898 fetch_room_distributions_into
9999 )
100100
0 from twisted.python import failure
1
02 from synapse.rest.client.v2_alpha.register import RegisterRestServlet
1 from synapse.api.errors import SynapseError
3 from synapse.api.errors import SynapseError, InteractiveAuthIncompleteError
24 from twisted.internet import defer
35 from mock import Mock
46 from tests import unittest
2325 side_effect=lambda x: self.appservice)
2426 )
2527
26 self.auth_result = (False, None, None, None)
28 self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
2729 self.auth_handler = Mock(
2830 check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
2931 get_session_data=Mock(return_value=None)
4648 self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
4749 self.hs.get_device_handler = Mock(return_value=self.device_handler)
4850 self.hs.config.enable_registration = True
51 self.hs.config.registrations_require_3pid = []
4952 self.hs.config.auto_join_rooms = []
5053
5154 # init the thing we're testing
8588 self.request.args = {
8689 "access_token": "i_am_an_app_service"
8790 }
91
8892 self.request_data = json.dumps({
8993 "username": "kermit"
9094 })
119123 "device_id": device_id,
120124 })
121125 self.registration_handler.check_username = Mock(return_value=True)
122 self.auth_result = (True, None, {
126 self.auth_result = (None, {
123127 "username": "kermit",
124128 "password": "monkey"
125129 }, None)
149153 "password": "monkey"
150154 })
151155 self.registration_handler.check_username = Mock(return_value=True)
152 self.auth_result = (True, None, {
156 self.auth_result = (None, {
153157 "username": "kermit",
154158 "password": "monkey"
155159 }, None)
0 # -*- coding: utf-8 -*-
1 # Copyright 2018 New Vector Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
0 # -*- coding: utf-8 -*-
1 # Copyright 2018 New Vector Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
0 # -*- coding: utf-8 -*-
1 # Copyright 2018 New Vector Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15
16 from twisted.internet import defer
17
18 from synapse.rest.media.v1._base import FileInfo
19 from synapse.rest.media.v1.media_storage import MediaStorage
20 from synapse.rest.media.v1.filepath import MediaFilePaths
21 from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
22
23 from mock import Mock
24
25 from tests import unittest
26
27 import os
28 import shutil
29 import tempfile
30
31
32 class MediaStorageTests(unittest.TestCase):
33 def setUp(self):
34 self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
35
36 self.primary_base_path = os.path.join(self.test_dir, "primary")
37 self.secondary_base_path = os.path.join(self.test_dir, "secondary")
38
39 hs = Mock()
40 hs.config.media_store_path = self.primary_base_path
41
42 storage_providers = [FileStorageProviderBackend(
43 hs, self.secondary_base_path
44 )]
45
46 self.filepaths = MediaFilePaths(self.primary_base_path)
47 self.media_storage = MediaStorage(
48 self.primary_base_path, self.filepaths, storage_providers,
49 )
50
51 def tearDown(self):
52 shutil.rmtree(self.test_dir)
53
54 @defer.inlineCallbacks
55 def test_ensure_media_is_in_local_cache(self):
56 media_id = "some_media_id"
57 test_body = "Test\n"
58
59 # First we create a file that is in a storage provider but not in the
60 # local primary media store
61 rel_path = self.filepaths.local_media_filepath_rel(media_id)
62 secondary_path = os.path.join(self.secondary_base_path, rel_path)
63
64 os.makedirs(os.path.dirname(secondary_path))
65
66 with open(secondary_path, "w") as f:
67 f.write(test_body)
68
69 # Now we run ensure_media_is_in_local_cache, which should copy the file
70 # to the local cache.
71 file_info = FileInfo(None, media_id)
72 local_path = yield self.media_storage.ensure_media_is_in_local_cache(file_info)
73
74 self.assertTrue(os.path.exists(local_path))
75
76 # Asserts the file is under the expected local cache directory
77 self.assertEquals(
78 os.path.commonprefix([self.primary_base_path, local_path]),
79 self.primary_base_path,
80 )
81
82 with open(local_path) as f:
83 body = f.read()
84
85 self.assertEqual(test_body, body)
4141 hs = yield setup_test_homeserver(
4242 config=config,
4343 federation_sender=Mock(),
44 replication_layer=Mock(),
44 federation_client=Mock(),
4545 )
4646
4747 self.as_token = "token1"
5757 self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
5858 self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
5959 # must be done after inserts
60 self.store = ApplicationServiceStore(hs)
60 self.store = ApplicationServiceStore(None, hs)
6161
6262 def tearDown(self):
6363 # TODO: suboptimal that we need to create files for tests!
6464 for f in self.as_yaml_files:
6565 try:
6666 os.remove(f)
67 except:
67 except Exception:
6868 pass
6969
7070 def _add_appservice(self, as_token, id, url, hs_token, sender):
118118 hs = yield setup_test_homeserver(
119119 config=config,
120120 federation_sender=Mock(),
121 replication_layer=Mock(),
121 federation_client=Mock(),
122122 )
123123 self.db_pool = hs.get_db_pool()
124124
149149
150150 self.as_yaml_files = []
151151
152 self.store = TestTransactionStore(hs)
152 self.store = TestTransactionStore(None, hs)
153153
154154 def _add_service(self, url, as_token, id):
155155 as_yaml = dict(url=url, as_token=as_token, hs_token="something",
419419 class TestTransactionStore(ApplicationServiceTransactionStore,
420420 ApplicationServiceStore):
421421
422 def __init__(self, hs):
423 super(TestTransactionStore, self).__init__(hs)
422 def __init__(self, db_conn, hs):
423 super(TestTransactionStore, self).__init__(db_conn, hs)
424424
425425
426426 class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
454454 config=config,
455455 datastore=Mock(),
456456 federation_sender=Mock(),
457 replication_layer=Mock(),
458 )
459
460 ApplicationServiceStore(hs)
457 federation_client=Mock(),
458 )
459
460 ApplicationServiceStore(None, hs)
461461
462462 @defer.inlineCallbacks
463463 def test_duplicate_ids(self):
472472 config=config,
473473 datastore=Mock(),
474474 federation_sender=Mock(),
475 replication_layer=Mock(),
475 federation_client=Mock(),
476476 )
477477
478478 with self.assertRaises(ConfigError) as cm:
479 ApplicationServiceStore(hs)
479 ApplicationServiceStore(None, hs)
480480
481481 e = cm.exception
482482 self.assertIn(f1, e.message)
496496 config=config,
497497 datastore=Mock(),
498498 federation_sender=Mock(),
499 replication_layer=Mock(),
499 federation_client=Mock(),
500500 )
501501
502502 with self.assertRaises(ConfigError) as cm:
503 ApplicationServiceStore(hs)
503 ApplicationServiceStore(None, hs)
504504
505505 e = cm.exception
506506 self.assertIn(f1, e.message)
5555 database_engine=create_engine(config.database_config),
5656 )
5757
58 self.datastore = SQLBaseStore(hs)
58 self.datastore = SQLBaseStore(None, hs)
5959
6060 @defer.inlineCallbacks
6161 def test_insert_1col(self):
2828 def setUp(self):
2929 hs = yield setup_test_homeserver()
3030
31 self.store = DirectoryStore(hs)
31 self.store = DirectoryStore(None, hs)
3232
3333 self.room = RoomID.from_string("!abcde:test")
3434 self.alias = RoomAlias.from_string("#my-room:test")
6161 {"notify_count": noitf_count, "highlight_count": highlight_count}
6262 )
6363
64 @defer.inlineCallbacks
6465 def _inject_actions(stream, action):
6566 event = Mock()
6667 event.room_id = room_id
6869 event.internal_metadata.stream_ordering = stream
6970 event.depth = stream
7071
71 tuples = [(user_id, action)]
72
73 return self.store.runInteraction(
72 yield self.store.add_push_actions_to_staging(
73 event.event_id, {user_id: action},
74 )
75 yield self.store.runInteraction(
7476 "", self.store._set_push_actions_for_event_and_users_txn,
75 event, tuples
77 [(event, None)], [(event, None)],
7678 )
7779
7880 def _rotate(stream):
124126 yield _assert_counts(1, 1)
125127 yield _rotate(10)
126128 yield _assert_counts(1, 1)
129
130 @tests.unittest.DEBUG
131 @defer.inlineCallbacks
132 def test_find_first_stream_ordering_after_ts(self):
133 def add_event(so, ts):
134 return self.store._simple_insert("events", {
135 "stream_ordering": so,
136 "received_ts": ts,
137 "event_id": "event%i" % so,
138 "type": "",
139 "room_id": "",
140 "content": "",
141 "processed": True,
142 "outlier": False,
143 "topological_ordering": 0,
144 "depth": 0,
145 })
146
147 # start with the base case where there are no events in the table
148 r = yield self.store.find_first_stream_ordering_after_ts(11)
149 self.assertEqual(r, 0)
150
151 # now with one event
152 yield add_event(2, 10)
153 r = yield self.store.find_first_stream_ordering_after_ts(9)
154 self.assertEqual(r, 2)
155 r = yield self.store.find_first_stream_ordering_after_ts(10)
156 self.assertEqual(r, 2)
157 r = yield self.store.find_first_stream_ordering_after_ts(11)
158 self.assertEqual(r, 3)
159
160 # add a bunch of dummy events to the events table
161 for (stream_ordering, ts) in (
162 (3, 110),
163 (4, 120),
164 (5, 120),
165 (10, 130),
166 (20, 140),
167 ):
168 yield add_event(stream_ordering, ts)
169
170 r = yield self.store.find_first_stream_ordering_after_ts(110)
171 self.assertEqual(r, 3,
172 "First event after 110ms should be 3, was %i" % r)
173
174 # 4 and 5 are both after 120: we want 4 rather than 5
175 r = yield self.store.find_first_stream_ordering_after_ts(120)
176 self.assertEqual(r, 4,
177 "First event after 120ms should be 4, was %i" % r)
178
179 r = yield self.store.find_first_stream_ordering_after_ts(129)
180 self.assertEqual(r, 10,
181 "First event after 129ms should be 10, was %i" % r)
182
183 # check we can get the last event
184 r = yield self.store.find_first_stream_ordering_after_ts(140)
185 self.assertEqual(r, 20,
186 "First event after 14ms should be 20, was %i" % r)
187
188 # off the end
189 r = yield self.store.find_first_stream_ordering_after_ts(160)
190 self.assertEqual(r, 21)
191
192 # check we can find an event at ordering zero
193 yield add_event(0, 5)
194 r = yield self.store.find_first_stream_ordering_after_ts(1)
195 self.assertEqual(r, 0)
2828 def setUp(self):
2929 hs = yield setup_test_homeserver(clock=MockClock())
3030
31 self.store = PresenceStore(hs)
31 self.store = PresenceStore(None, hs)
3232
3333 self.u_apple = UserID.from_string("@apple:test")
3434 self.u_banana = UserID.from_string("@banana:test")
2828 def setUp(self):
2929 hs = yield setup_test_homeserver()
3030
31 self.store = ProfileStore(hs)
31 self.store = ProfileStore(None, hs)
3232
3333 self.u_frank = UserID.from_string("@frank:test")
3434
3535
3636 self.store = hs.get_datastore()
3737 self.event_builder_factory = hs.get_event_builder_factory()
38 self.handlers = hs.get_handlers()
39 self.message_handler = self.handlers.message_handler
38 self.event_creation_handler = hs.get_event_creation_handler()
4039
4140 self.u_alice = UserID.from_string("@alice:test")
4241 self.u_bob = UserID.from_string("@bob:test")
5857 "content": content,
5958 })
6059
61 event, context = yield self.message_handler._create_new_client_event(
60 event, context = yield self.event_creation_handler.create_new_client_event(
6261 builder
6362 )
6463
7877 "content": {"body": body, "msgtype": u"message"},
7978 })
8079
81 event, context = yield self.message_handler._create_new_client_event(
80 event, context = yield self.event_creation_handler.create_new_client_event(
8281 builder
8382 )
8483
9796 "redacts": event_id,
9897 })
9998
100 event, context = yield self.message_handler._create_new_client_event(
99 event, context = yield self.event_creation_handler.create_new_client_event(
101100 builder
102101 )
103102
8585
8686 # now delete some
8787 yield self.store.user_delete_access_tokens(
88 self.user_id, device_id=self.device_id, delete_refresh_tokens=True)
88 self.user_id, device_id=self.device_id,
89 )
8990
9091 # check they were deleted
9192 user = yield self.store.get_user_by_access_token(self.tokens[1])
9697 self.assertEqual(self.user_id, user["name"])
9798
9899 # now delete the rest
99 yield self.store.user_delete_access_tokens(
100 self.user_id, delete_refresh_tokens=True)
100 yield self.store.user_delete_access_tokens(self.user_id)
101101
102102 user = yield self.store.get_user_by_access_token(self.tokens[0])
103103 self.assertIsNone(user,
3636 # storage logic
3737 self.store = hs.get_datastore()
3838 self.event_builder_factory = hs.get_event_builder_factory()
39 self.handlers = hs.get_handlers()
40 self.message_handler = self.handlers.message_handler
39 self.event_creation_handler = hs.get_event_creation_handler()
4140
4241 self.u_alice = UserID.from_string("@alice:test")
4342 self.u_bob = UserID.from_string("@bob:test")
5756 "content": {"membership": membership},
5857 })
5958
60 event, context = yield self.message_handler._create_new_client_event(
59 event, context = yield self.event_creation_handler.create_new_client_event(
6160 builder
6261 )
6362
0 # -*- coding: utf-8 -*-
1 # Copyright 2018 New Vector Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from twisted.internet import defer
16
17 from synapse.storage import UserDirectoryStore
18 from synapse.storage.roommember import ProfileInfo
19 from tests import unittest
20 from tests.utils import setup_test_homeserver
21
22 ALICE = "@alice:a"
23 BOB = "@bob:b"
24 BOBBY = "@bobby:a"
25
26
27 class UserDirectoryStoreTestCase(unittest.TestCase):
28 @defer.inlineCallbacks
29 def setUp(self):
30 self.hs = yield setup_test_homeserver()
31 self.store = UserDirectoryStore(None, self.hs)
32
33 # alice and bob are both in !room_id. bobby is not but shares
34 # a homeserver with alice.
35 yield self.store.add_profiles_to_user_dir(
36 "!room:id",
37 {
38 ALICE: ProfileInfo(None, "alice"),
39 BOB: ProfileInfo(None, "bob"),
40 BOBBY: ProfileInfo(None, "bobby")
41 },
42 )
43 yield self.store.add_users_to_public_room(
44 "!room:id",
45 [ALICE, BOB],
46 )
47 yield self.store.add_users_who_share_room(
48 "!room:id",
49 False,
50 (
51 (ALICE, BOB),
52 (BOB, ALICE),
53 ),
54 )
55
56 @defer.inlineCallbacks
57 def test_search_user_dir(self):
58 # normally when alice searches the directory she should just find
59 # bob because bobby doesn't share a room with her.
60 r = yield self.store.search_user_dir(ALICE, "bob", 10)
61 self.assertFalse(r["limited"])
62 self.assertEqual(1, len(r["results"]))
63 self.assertDictEqual(r["results"][0], {
64 "user_id": BOB,
65 "display_name": "bob",
66 "avatar_url": None,
67 })
68
69 @defer.inlineCallbacks
70 def test_search_user_dir_all_users(self):
71 self.hs.config.user_directory_search_all_users = True
72 try:
73 r = yield self.store.search_user_dir(ALICE, "bob", 10)
74 self.assertFalse(r["limited"])
75 self.assertEqual(2, len(r["results"]))
76 self.assertDictEqual(r["results"][0], {
77 "user_id": BOB,
78 "display_name": "bob",
79 "avatar_url": None,
80 })
81 self.assertDictEqual(r["results"][1], {
82 "user_id": BOBBY,
83 "display_name": "bobby",
84 "avatar_url": None,
85 })
86 finally:
87 self.hs.config.user_directory_search_all_users = False
1818 from synapse.events import FrozenEvent
1919 from synapse.api.auth import Auth
2020 from synapse.api.constants import EventTypes, Membership
21 from synapse.state import StateHandler
21 from synapse.state import StateHandler, StateResolutionHandler
2222
2323 from .utils import MockClock
2424
7979
8080 return defer.succeed(groups)
8181
82 def store_state_groups(self, event, context):
83 if context.current_state_ids is None:
84 return
85
86 state_events = dict(context.current_state_ids)
87
88 self._group_to_state[context.state_group] = state_events
89 self._event_to_state_group[event.event_id] = context.state_group
82 def store_state_group(self, event_id, room_id, prev_group, delta_ids,
83 current_state_ids):
84 state_group = self._next_group
85 self._next_group += 1
86
87 self._group_to_state[state_group] = dict(current_state_ids)
88
89 return state_group
9090
9191 def get_events(self, event_ids, **kwargs):
9292 return {
9494 if e_id in self._event_id_to_event
9595 }
9696
97 def get_state_group_delta(self, name):
98 return (None, None)
99
97100 def register_events(self, events):
98101 for e in events:
99102 self._event_id_to_event[e.event_id] = e
103
104 def register_event_context(self, event, context):
105 self._event_to_state_group[event.event_id] = context.state_group
106
107 def register_event_id_state_group(self, event_id, state_group):
108 self._event_to_state_group[event_id] = state_group
100109
101110
102111 class DictObj(dict):
136145
137146 class StateTestCase(unittest.TestCase):
138147 def setUp(self):
139 self.store = Mock(
140 spec_set=[
141 "get_state_groups_ids",
142 "add_event_hashes",
143 "get_events",
144 "get_next_state_group",
145 "get_state_group_delta",
146 ]
147 )
148 self.store = StateGroupStore()
148149 hs = Mock(spec_set=[
149150 "get_datastore", "get_auth", "get_state_handler", "get_clock",
151 "get_state_resolution_handler",
150152 ])
151153 hs.get_datastore.return_value = self.store
152154 hs.get_state_handler.return_value = None
153155 hs.get_clock.return_value = MockClock()
154156 hs.get_auth.return_value = Auth(hs)
155
156 self.store.get_next_state_group.side_effect = Mock
157 self.store.get_state_group_delta.return_value = (None, None)
157 hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
158158
159159 self.state = StateHandler(hs)
160160 self.event_id = 0
194194 }
195195 )
196196
197 store = StateGroupStore()
198 self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
197 self.store.register_events(graph.walk())
199198
200199 context_store = {}
201200
202201 for event in graph.walk():
203202 context = yield self.state.compute_event_context(event)
204 store.store_state_groups(event, context)
203 self.store.register_event_context(event, context)
205204 context_store[event.event_id] = context
206205
207206 self.assertEqual(2, len(context_store["D"].prev_state_ids))
246245 }
247246 )
248247
249 store = StateGroupStore()
250 self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
251 self.store.get_events = store.get_events
252 store.register_events(graph.walk())
248 self.store.register_events(graph.walk())
253249
254250 context_store = {}
255251
256252 for event in graph.walk():
257253 context = yield self.state.compute_event_context(event)
258 store.store_state_groups(event, context)
254 self.store.register_event_context(event, context)
259255 context_store[event.event_id] = context
260256
261257 self.assertSetEqual(
312308 }
313309 )
314310
315 store = StateGroupStore()
316 self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
317 self.store.get_events = store.get_events
318 store.register_events(graph.walk())
311 self.store.register_events(graph.walk())
319312
320313 context_store = {}
321314
322315 for event in graph.walk():
323316 context = yield self.state.compute_event_context(event)
324 store.store_state_groups(event, context)
317 self.store.register_event_context(event, context)
325318 context_store[event.event_id] = context
326319
327320 self.assertSetEqual(
395388 self._add_depths(nodes, edges)
396389 graph = Graph(nodes, edges)
397390
398 store = StateGroupStore()
399 self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
400 self.store.get_events = store.get_events
401 store.register_events(graph.walk())
391 self.store.register_events(graph.walk())
402392
403393 context_store = {}
404394
405395 for event in graph.walk():
406396 context = yield self.state.compute_event_context(event)
407 store.store_state_groups(event, context)
397 self.store.register_event_context(event, context)
408398 context_store[event.event_id] = context
409399
410400 self.assertSetEqual(
464454
465455 @defer.inlineCallbacks
466456 def test_trivial_annotate_message(self):
467 event = create_event(type="test_message", name="event")
457 prev_event_id = "prev_event_id"
458 event = create_event(
459 type="test_message", name="event2",
460 prev_events=[(prev_event_id, {})],
461 )
468462
469463 old_state = [
470464 create_event(type="test1", state_key="1"),
472466 create_event(type="test2", state_key=""),
473467 ]
474468
475 group_name = "group_name_1"
476
477 self.store.get_state_groups_ids.return_value = {
478 group_name: {(e.type, e.state_key): e.event_id for e in old_state},
479 }
469 group_name = self.store.store_state_group(
470 prev_event_id, event.room_id, None, None,
471 {(e.type, e.state_key): e.event_id for e in old_state},
472 )
473 self.store.register_event_id_state_group(prev_event_id, group_name)
480474
481475 context = yield self.state.compute_event_context(event)
482476
489483
490484 @defer.inlineCallbacks
491485 def test_trivial_annotate_state(self):
492 event = create_event(type="state", state_key="", name="event")
486 prev_event_id = "prev_event_id"
487 event = create_event(
488 type="state", state_key="", name="event2",
489 prev_events=[(prev_event_id, {})],
490 )
493491
494492 old_state = [
495493 create_event(type="test1", state_key="1"),
497495 create_event(type="test2", state_key=""),
498496 ]
499497
500 group_name = "group_name_1"
501
502 self.store.get_state_groups_ids.return_value = {
503 group_name: {(e.type, e.state_key): e.event_id for e in old_state},
504 }
498 group_name = self.store.store_state_group(
499 prev_event_id, event.room_id, None, None,
500 {(e.type, e.state_key): e.event_id for e in old_state},
501 )
502 self.store.register_event_id_state_group(prev_event_id, group_name)
505503
506504 context = yield self.state.compute_event_context(event)
507505
514512
515513 @defer.inlineCallbacks
516514 def test_resolve_message_conflict(self):
517 event = create_event(type="test_message", name="event")
515 prev_event_id1 = "event_id1"
516 prev_event_id2 = "event_id2"
517 event = create_event(
518 type="test_message", name="event3",
519 prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
520 )
521
522 creation = create_event(
523 type=EventTypes.Create, state_key=""
524 )
525
526 old_state_1 = [
527 creation,
528 create_event(type="test1", state_key="1"),
529 create_event(type="test1", state_key="2"),
530 create_event(type="test2", state_key=""),
531 ]
532
533 old_state_2 = [
534 creation,
535 create_event(type="test1", state_key="1"),
536 create_event(type="test3", state_key="2"),
537 create_event(type="test4", state_key=""),
538 ]
539
540 self.store.register_events(old_state_1)
541 self.store.register_events(old_state_2)
542
543 context = yield self._get_context(
544 event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
545 )
546
547 self.assertEqual(len(context.current_state_ids), 6)
548
549 self.assertIsNotNone(context.state_group)
550
551 @defer.inlineCallbacks
552 def test_resolve_state_conflict(self):
553 prev_event_id1 = "event_id1"
554 prev_event_id2 = "event_id2"
555 event = create_event(
556 type="test4", state_key="", name="event",
557 prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
558 )
518559
519560 creation = create_event(
520561 type=EventTypes.Create, state_key=""
539580 store.register_events(old_state_2)
540581 self.store.get_events = store.get_events
541582
542 context = yield self._get_context(event, old_state_1, old_state_2)
583 context = yield self._get_context(
584 event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
585 )
543586
544587 self.assertEqual(len(context.current_state_ids), 6)
545588
546589 self.assertIsNotNone(context.state_group)
547590
548591 @defer.inlineCallbacks
549 def test_resolve_state_conflict(self):
550 event = create_event(type="test4", state_key="", name="event")
551
552 creation = create_event(
553 type=EventTypes.Create, state_key=""
554 )
555
556 old_state_1 = [
557 creation,
558 create_event(type="test1", state_key="1"),
559 create_event(type="test1", state_key="2"),
560 create_event(type="test2", state_key=""),
561 ]
562
563 old_state_2 = [
564 creation,
565 create_event(type="test1", state_key="1"),
566 create_event(type="test3", state_key="2"),
567 create_event(type="test4", state_key=""),
568 ]
569
570 store = StateGroupStore()
571 store.register_events(old_state_1)
572 store.register_events(old_state_2)
573 self.store.get_events = store.get_events
574
575 context = yield self._get_context(event, old_state_1, old_state_2)
576
577 self.assertEqual(len(context.current_state_ids), 6)
578
579 self.assertIsNotNone(context.state_group)
580
581 @defer.inlineCallbacks
582592 def test_standard_depth_conflict(self):
583 event = create_event(type="test4", name="event")
593 prev_event_id1 = "event_id1"
594 prev_event_id2 = "event_id2"
595 event = create_event(
596 type="test4", name="event",
597 prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
598 )
584599
585600 member_event = create_event(
586601 type=EventTypes.Member,
612627 store.register_events(old_state_2)
613628 self.store.get_events = store.get_events
614629
615 context = yield self._get_context(event, old_state_1, old_state_2)
630 context = yield self._get_context(
631 event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
632 )
616633
617634 self.assertEqual(
618635 old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
636653 store.register_events(old_state_1)
637654 store.register_events(old_state_2)
638655
639 context = yield self._get_context(event, old_state_1, old_state_2)
656 context = yield self._get_context(
657 event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
658 )
640659
641660 self.assertEqual(
642661 old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
643662 )
644663
645 def _get_context(self, event, old_state_1, old_state_2):
646 group_name_1 = "group_name_1"
647 group_name_2 = "group_name_2"
648
649 self.store.get_state_groups_ids.return_value = {
650 group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
651 group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
652 }
664 def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,
665 old_state_2):
666 sg1 = self.store.store_state_group(
667 prev_event_id_1, event.room_id, None, None,
668 {(e.type, e.state_key): e.event_id for e in old_state_1},
669 )
670 self.store.register_event_id_state_group(prev_event_id_1, sg1)
671
672 sg2 = self.store.store_state_group(
673 prev_event_id_2, event.room_id, None, None,
674 {(e.type, e.state_key): e.event_id for e in old_state_2},
675 )
676 self.store.register_event_id_state_group(prev_event_id_2, sg2)
653677
654678 return self.state.compute_event_context(event)
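# Not part of the test file above: a hedged summary (field names taken from
# the assertions in these tests) of what compute_event_context hands back via
# the _get_context helper - an EventContext whose attributes the tests
# inspect directly:
#
#     context = yield self._get_context(
#         event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
#     )
#     context.prev_state_ids     # {(type, state_key): event_id} before the event
#     context.current_state_ids  # state after resolving the two branches
#     context.state_group        # the (possibly newly created) state group id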
1616
1717 from synapse.api.errors import SynapseError
1818 from synapse.server import HomeServer
19 from synapse.types import UserID, RoomAlias
19 from synapse.types import UserID, RoomAlias, GroupID
2020
2121 mock_homeserver = HomeServer(hostname="my.domain")
2222
5959 room = RoomAlias("channel", "my.domain")
6060
6161 self.assertEquals(room.to_string(), "#channel:my.domain")
62
63
64 class GroupIDTestCase(unittest.TestCase):
65 def test_parse(self):
66 group_id = GroupID.from_string("+group/=_-.123:my.domain")
67 self.assertEqual("group/=_-.123", group_id.localpart)
68 self.assertEqual("my.domain", group_id.domain)
69
70 def test_validate(self):
71 bad_ids = [
72 "$badsigil:domain",
73 "+:empty",
74 ] + [
75 "+group" + c + ":domain" for c in "A%?æ£"
76 ]
77 for id_string in bad_ids:
78 try:
79 GroupID.from_string(id_string)
80 self.fail("Parsing '%s' should raise exception" % id_string)
81 except SynapseError as exc:
82 self.assertEqual(400, exc.code)
83 self.assertEqual("M_UNKNOWN", exc.errcode)
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14
14 import twisted
1515 from twisted.trial import unittest
1616
1717 import logging
6464
6565 @around(self)
6666 def setUp(orig):
67 # enable debugging of delayed calls - this means that we get a
68 # traceback when a unit test exits leaving things on the reactor.
69 twisted.internet.base.DelayedCall.debug = True
70
6771 old_level = logging.getLogger().level
6872
6973 if old_level != level:
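# Not part of the change above: a minimal, hypothetical illustration of what
# DelayedCall.debug buys you.  With the flag enabled (as in setUp above), a
# call left pending when a test finishes is reported together with the
# traceback of the code that scheduled it, rather than as an anonymous
# pending timed call.
from twisted.internet import base, reactor

base.DelayedCall.debug = True
leftover = reactor.callLater(30, lambda: None)  # a timer a test forgot to cancel
# trial's reactor-cleanliness check would now point at the callLater() above;
# cancel it here so this sketch itself leaves the reactor clean.
leftover.cancel()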
0 # -*- coding: utf-8 -*-
1 # Copyright 2018 New Vector Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15
16 from twisted.internet import defer, reactor
17 from mock import NonCallableMock
18
19 from synapse.util.file_consumer import BackgroundFileConsumer
20
21 from tests import unittest
22 from StringIO import StringIO
23
24 import threading
25
26
27 class FileConsumerTests(unittest.TestCase):
28
29 @defer.inlineCallbacks
30 def test_pull_consumer(self):
31 string_file = StringIO()
32 consumer = BackgroundFileConsumer(string_file)
33
34 try:
35 producer = DummyPullProducer()
36
37 yield producer.register_with_consumer(consumer)
38
39 yield producer.write_and_wait("Foo")
40
41 self.assertEqual(string_file.getvalue(), "Foo")
42
43 yield producer.write_and_wait("Bar")
44
45 self.assertEqual(string_file.getvalue(), "FooBar")
46 finally:
47 consumer.unregisterProducer()
48
49 yield consumer.wait()
50
51 self.assertTrue(string_file.closed)
52
53 @defer.inlineCallbacks
54 def test_push_consumer(self):
55 string_file = BlockingStringWrite()
56 consumer = BackgroundFileConsumer(string_file)
57
58 try:
59 producer = NonCallableMock(spec_set=[])
60
61 consumer.registerProducer(producer, True)
62
63 consumer.write("Foo")
64 yield string_file.wait_for_n_writes(1)
65
66 self.assertEqual(string_file.buffer, "Foo")
67
68 consumer.write("Bar")
69 yield string_file.wait_for_n_writes(2)
70
71 self.assertEqual(string_file.buffer, "FooBar")
72 finally:
73 consumer.unregisterProducer()
74
75 yield consumer.wait()
76
77 self.assertTrue(string_file.closed)
78
79 @defer.inlineCallbacks
80 def test_push_producer_feedback(self):
81 string_file = BlockingStringWrite()
82 consumer = BackgroundFileConsumer(string_file)
83
84 try:
85 producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
86
87 resume_deferred = defer.Deferred()
88 producer.resumeProducing.side_effect = lambda: resume_deferred.callback(None)
89
90 consumer.registerProducer(producer, True)
91
92 number_writes = 0
93 with string_file.write_lock:
94 for _ in range(consumer._PAUSE_ON_QUEUE_SIZE):
95 consumer.write("Foo")
96 number_writes += 1
97
98 producer.pauseProducing.assert_called_once()
99
100 yield string_file.wait_for_n_writes(number_writes)
101
102 yield resume_deferred
103 producer.resumeProducing.assert_called_once()
104 finally:
105 consumer.unregisterProducer()
106
107 yield consumer.wait()
108
109 self.assertTrue(string_file.closed)
110
111
112 class DummyPullProducer(object):
113 def __init__(self):
114 self.consumer = None
115 self.deferred = defer.Deferred()
116
117 def resumeProducing(self):
118 d = self.deferred
119 self.deferred = defer.Deferred()
120 d.callback(None)
121
122 def write_and_wait(self, bytes):
123 d = self.deferred
124 self.consumer.write(bytes)
125 return d
126
127 def register_with_consumer(self, consumer):
128 d = self.deferred
129 self.consumer = consumer
130 self.consumer.registerProducer(self, False)
131 return d
132
133
134 class BlockingStringWrite(object):
135 def __init__(self):
136 self.buffer = ""
137 self.closed = False
138 self.write_lock = threading.Lock()
139
140 self._notify_write_deferred = None
141 self._number_of_writes = 0
142
143 def write(self, bytes):
144 with self.write_lock:
145 self.buffer += bytes
146 self._number_of_writes += 1
147
148 reactor.callFromThread(self._notify_write)
149
150 def close(self):
151 self.closed = True
152
153 def _notify_write(self):
154 "Called by write to indicate a write happened"
155 with self.write_lock:
156 if not self._notify_write_deferred:
157 return
158 d = self._notify_write_deferred
159 self._notify_write_deferred = None
160 d.callback(None)
161
162 @defer.inlineCallbacks
163 def wait_for_n_writes(self, n):
164 "Wait for n writes to have happened"
165 while True:
166 with self.write_lock:
167 if n <= self._number_of_writes:
168 return
169
170 if not self._notify_write_deferred:
171 self._notify_write_deferred = defer.Deferred()
172
173 d = self._notify_write_deferred
174
175 yield d
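# Not part of the file above: a hedged usage sketch of BackgroundFileConsumer
# as a push consumer, using only the calls exercised by the tests (the path
# and helper name are illustrative assumptions).  Writes are queued to a
# background thread; wait() resolves once everything has been flushed.
from mock import NonCallableMock
from twisted.internet import defer

from synapse.util.file_consumer import BackgroundFileConsumer


@defer.inlineCallbacks
def copy_chunks_to_file(path, chunks):
    consumer = BackgroundFileConsumer(open(path, "wb"))
    producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
    consumer.registerProducer(producer, True)  # True = push (streaming) producer
    try:
        for chunk in chunks:
            consumer.write(chunk)
    finally:
        consumer.unregisterProducer()
    yield consumer.wait()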
1111
1212 def _check_test_key(self, value):
1313 self.assertEquals(
14 LoggingContext.current_context().test_key, value
14 LoggingContext.current_context().request, value
1515 )
1616
1717 def test_with_context(self):
1818 with LoggingContext() as context_one:
19 context_one.test_key = "test"
19 context_one.request = "test"
2020 self._check_test_key("test")
2121
2222 @defer.inlineCallbacks
2424 @defer.inlineCallbacks
2525 def competing_callback():
2626 with LoggingContext() as competing_context:
27 competing_context.test_key = "competing"
27 competing_context.request = "competing"
2828 yield sleep(0)
2929 self._check_test_key("competing")
3030
3131 reactor.callLater(0, competing_callback)
3232
3333 with LoggingContext() as context_one:
34 context_one.test_key = "one"
34 context_one.request = "one"
3535 yield sleep(0)
3636 self._check_test_key("one")
3737
4242
4343 @defer.inlineCallbacks
4444 def cb():
45 context_one.test_key = "one"
45 context_one.request = "one"
4646 yield function()
4747 self._check_test_key("one")
4848
4949 callback_completed[0] = True
5050
5151 with LoggingContext() as context_one:
52 context_one.test_key = "one"
52 context_one.request = "one"
5353
5454 # fire off function, but don't wait on it.
5555 logcontext.preserve_fn(cb)()
106106 sentinel_context = LoggingContext.current_context()
107107
108108 with LoggingContext() as context_one:
109 context_one.test_key = "one"
109 context_one.request = "one"
110110
111111 d1 = logcontext.make_deferred_yieldable(blocking_function())
112112 # make sure that the context was reset by make_deferred_yieldable
123123 argument isn't actually a deferred"""
124124
125125 with LoggingContext() as context_one:
126 context_one.test_key = "one"
126 context_one.request = "one"
127127
128128 d1 = logcontext.make_deferred_yieldable("bum")
129129 self._check_test_key("one")
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 from synapse.http.server import HttpServer
16 from synapse.api.errors import cs_error, CodeMessageException, StoreError
17 from synapse.api.constants import EventTypes
18 from synapse.storage.prepare_database import prepare_database
19 from synapse.storage.engines import create_engine
20 from synapse.server import HomeServer
21 from synapse.federation.transport import server
22 from synapse.util.ratelimitutils import FederationRateLimiter
23
24 from synapse.util.logcontext import LoggingContext
25
26 from twisted.internet import defer, reactor
27 from twisted.enterprise.adbapi import ConnectionPool
28
29 from collections import namedtuple
30 from mock import patch, Mock
3115 import hashlib
16 from inspect import getcallargs
3217 import urllib
3318 import urlparse
3419
35 from inspect import getcallargs
20 from mock import Mock, patch
21 from twisted.internet import defer, reactor
22
23 from synapse.api.errors import CodeMessageException, cs_error
24 from synapse.federation.transport import server
25 from synapse.http.server import HttpServer
26 from synapse.server import HomeServer
27 from synapse.storage import PostgresEngine
28 from synapse.storage.engines import create_engine
29 from synapse.storage.prepare_database import prepare_database
30 from synapse.util.logcontext import LoggingContext
31 from synapse.util.ratelimitutils import FederationRateLimiter
32
33 # set this to True to run the tests against postgres instead of sqlite.
34 # It requires you to have a local postgres database called synapse_test, within
35 # which ALL TABLES WILL BE DROPPED
36 USE_POSTGRES_FOR_TESTS = False
3637
3738
3839 @defer.inlineCallbacks
5657 config.worker_app = None
5758 config.email_enable_notifs = False
5859 config.block_non_admin_invites = False
60 config.federation_domain_whitelist = None
61 config.user_directory_search_all_users = False
62
63 # disable user directory updates, because they get done in the
64 # background, which upsets the test runner.
65 config.update_user_directory = False
5966
6067 config.use_frozen_dicts = True
61 config.database_config = {"name": "sqlite3"}
6268 config.ldap_enabled = False
6369
6470 if "clock" not in kargs:
6571 kargs["clock"] = MockClock()
6672
73 if USE_POSTGRES_FOR_TESTS:
74 config.database_config = {
75 "name": "psycopg2",
76 "args": {
77 "database": "synapse_test",
78 "cp_min": 1,
79 "cp_max": 5,
80 },
81 }
82 else:
83 config.database_config = {
84 "name": "sqlite3",
85 "args": {
86 "database": ":memory:",
87 "cp_min": 1,
88 "cp_max": 1,
89 },
90 }
91
92 db_engine = create_engine(config.database_config)
93
94 # we need to configure the connection pool to run the on_new_connection
95 # function, so that we can test code that uses custom sqlite functions
96 # (like rank).
97 config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
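    # Not part of the change above: a hedged sketch (plain sqlite3, not the
    # engine class itself) of the kind of hook cp_openfun points at.  It runs
    # once per new connection, which is why it is the right place to register
    # custom SQLite functions such as rank(); the trivial implementation
    # below is an assumption for illustration only.
    #
    #     import sqlite3
    #
    #     def _example_on_new_connection(conn):
    #         # make rank() callable from SQL issued over this connection
    #         conn.create_function("rank", 1, lambda matchinfo: 1.0)
    #
    #     conn = sqlite3.connect(":memory:")
    #     _example_on_new_connection(conn)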
98
6799 if datastore is None:
68 db_pool = SQLiteMemoryDbPool()
69 yield db_pool.prepare()
70100 hs = HomeServer(
71 name, db_pool=db_pool, config=config,
101 name, config=config,
102 db_config=config.database_config,
72103 version_string="Synapse/tests",
73 database_engine=create_engine(config.database_config),
74 get_db_conn=db_pool.get_db_conn,
104 database_engine=db_engine,
75105 room_list_handler=object(),
76106 tls_server_context_factory=Mock(),
77107 **kargs
78108 )
109 db_conn = hs.get_db_conn()
110 # make sure that the database is empty
111 if isinstance(db_engine, PostgresEngine):
112 cur = db_conn.cursor()
113 cur.execute("SELECT tablename FROM pg_tables where schemaname='public'")
114 rows = cur.fetchall()
115 for r in rows:
116 cur.execute("DROP TABLE %s CASCADE" % r[0])
117 yield prepare_database(db_conn, db_engine, config)
79118 hs.setup()
80119 else:
81120 hs = HomeServer(
82121 name, db_pool=None, datastore=datastore, config=config,
83122 version_string="Synapse/tests",
84 database_engine=create_engine(config.database_config),
123 database_engine=db_engine,
85124 room_list_handler=object(),
86125 tls_server_context_factory=Mock(),
87126 **kargs
183222 mock_request.args = urlparse.parse_qs(path.split('?')[1])
184223 mock_request.path = path.split('?')[0]
185224 path = mock_request.path
186 except:
225 except Exception:
187226 pass
188227
189228 for (method, pattern, func) in self.callbacks:
300339 return d
301340
302341
303 class SQLiteMemoryDbPool(ConnectionPool, object):
304 def __init__(self):
305 super(SQLiteMemoryDbPool, self).__init__(
306 "sqlite3", ":memory:",
307 cp_min=1,
308 cp_max=1,
309 )
310
311 self.config = Mock()
312 self.config.database_config = {"name": "sqlite3"}
313
314 def prepare(self):
315 engine = self.create_engine()
316 return self.runWithConnection(
317 lambda conn: prepare_database(conn, engine, self.config)
318 )
319
320 def get_db_conn(self):
321 conn = self.connect()
322 engine = self.create_engine()
323 prepare_database(conn, engine, self.config)
324 return conn
325
326 def create_engine(self):
327 return create_engine(self.config.database_config)
328
329
330 class MemoryDataStore(object):
331
332 Room = namedtuple(
333 "Room",
334 ["room_id", "is_public", "creator"]
335 )
336
337 def __init__(self):
338 self.tokens_to_users = {}
339 self.paths_to_content = {}
340
341 self.members = {}
342 self.rooms = {}
343
344 self.current_state = {}
345 self.events = []
346
347 class Snapshot(namedtuple("Snapshot", "room_id user_id membership_state")):
348 def fill_out_prev_events(self, event):
349 pass
350
351 def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
352 return self.Snapshot(
353 room_id, user_id, self.get_room_member(user_id, room_id)
354 )
355
356 def register(self, user_id, token, password_hash):
357 if user_id in self.tokens_to_users.values():
358 raise StoreError(400, "User in use.")
359 self.tokens_to_users[token] = user_id
360
361 def get_user_by_access_token(self, token):
362 try:
363 return {
364 "name": self.tokens_to_users[token],
365 }
366 except:
367 raise StoreError(400, "User does not exist.")
368
369 def get_room(self, room_id):
370 try:
371 return self.rooms[room_id]
372 except:
373 return None
374
375 def store_room(self, room_id, room_creator_user_id, is_public):
376 if room_id in self.rooms:
377 raise StoreError(409, "Conflicting room!")
378
379 room = MemoryDataStore.Room(
380 room_id=room_id,
381 is_public=is_public,
382 creator=room_creator_user_id
383 )
384 self.rooms[room_id] = room
385
386 def get_room_member(self, user_id, room_id):
387 return self.members.get(room_id, {}).get(user_id)
388
389 def get_room_members(self, room_id, membership=None):
390 if membership:
391 return [
392 v for k, v in self.members.get(room_id, {}).items()
393 if v.membership == membership
394 ]
395 else:
396 return self.members.get(room_id, {}).values()
397
398 def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
399 return [
400 m[user_id] for m in self.members.values()
401 if user_id in m and m[user_id].membership in membership_list
402 ]
403
404 def get_room_events_stream(self, user_id=None, from_key=None, to_key=None,
405 limit=0, with_feedback=False):
406 return ([], from_key) # TODO
407
408 def get_joined_hosts_for_room(self, room_id):
409 return defer.succeed([])
410
411 def persist_event(self, event):
412 if event.type == EventTypes.Member:
413 room_id = event.room_id
414 user = event.state_key
415 self.members.setdefault(room_id, {})[user] = event
416
417 if hasattr(event, "state_key"):
418 key = (event.room_id, event.type, event.state_key)
419 self.current_state[key] = event
420
421 self.events.append(event)
422
423 def get_current_state(self, room_id, event_type=None, state_key=""):
424 if event_type:
425 key = (room_id, event_type, state_key)
426 if self.current_state.get(key):
427 return [self.current_state.get(key)]
428 return None
429 else:
430 return [
431 e for e in self.current_state
432 if e[0] == room_id
433 ]
434
435 def set_presence_state(self, user_localpart, state):
436 return defer.succeed({"state": 0})
437
438 def get_presence_list(self, user_localpart, accepted):
439 return []
440
441 def get_room_events_max_id(self):
442 return "s0" # TODO (erikj)
443
444 def get_send_event_level(self, room_id):
445 return defer.succeed(0)
446
447 def get_power_level(self, room_id, user_id):
448 return defer.succeed(0)
449
450 def get_add_state_level(self, room_id):
451 return defer.succeed(0)
452
453 def get_room_join_rule(self, room_id):
454 # TODO (erikj): This should be configurable
455 return defer.succeed("invite")
456
457 def get_ops_levels(self, room_id):
458 return defer.succeed((5, 5, 5))
459
460 def insert_client_ip(self, user, access_token, ip, user_agent):
461 return defer.succeed(None)
462
463
464342 def _format_call(args, kwargs):
465343 return ", ".join(
466344 ["%r" % (a) for a in args] +
498376 for _, _, d in self.expectations:
499377 try:
500378 d.errback(failure)
501 except:
379 except Exception:
502380 pass
503381
504382 raise failure