matrix-synapse / c263062
Merge branch 'debian/unstable' into debian/buster-fasttrack Andrej Shadura 3 years ago
104 changed files with 2728 additions and 954 deletions.
0 # Black reformatting (#5482).
1 32e7c9e7f20b57dd081023ac42d6931a8da9b3a3
2
3 # Target Python 3.5 with black (#8664).
4 aff1eb7c671b0a3813407321d2702ec46c71fa56
5
6 # Update black to 20.8b1 (#9381).
7 0a00b7ff14890987f09112a2ae696c61001e6cf1
0 Synapse 1.30.0 (2021-03-22)
1 ===========================
2
3 Note that this release deprecates the ability for appservices to
4 call `POST /_matrix/client/r0/register` without the body parameter `type`. Appservice
5 developers should use a `type` value of `m.login.application_service` as
6 per [the spec](https://matrix.org/docs/spec/application_service/r0.1.2#server-admin-style-permissions).
7 In future releases, calling this endpoint with an access token - but without an `m.login.application_service`
8 type - will fail.
9
10
11 No significant changes.
12
13
14 Synapse 1.30.0rc1 (2021-03-16)
15 ==============================
16
17 Features
18 --------
19
20 - Add prometheus metrics for number of users successfully registering and logging in. ([\#9510](https://github.com/matrix-org/synapse/issues/9510), [\#9511](https://github.com/matrix-org/synapse/issues/9511), [\#9573](https://github.com/matrix-org/synapse/issues/9573))
21 - Add `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time` prometheus metrics, which monitor federation delays by reporting the timestamps of messages sent and received to a set of remote servers. ([\#9540](https://github.com/matrix-org/synapse/issues/9540))
22 - Add support for generating JSON Web Tokens dynamically for use as OIDC client secrets. ([\#9549](https://github.com/matrix-org/synapse/issues/9549))
23 - Optimise handling of incomplete room history for incoming federation. ([\#9601](https://github.com/matrix-org/synapse/issues/9601))
24 - Finalise support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)). ([\#9617](https://github.com/matrix-org/synapse/issues/9617))
25 - Tell spam checker modules about the SSO IdP a user registered through if one was used. ([\#9626](https://github.com/matrix-org/synapse/issues/9626))
26
27
28 Bugfixes
29 --------
30
31 - Fix long-standing bug when generating thumbnails for some images with transparency: `TypeError: cannot unpack non-iterable int object`. ([\#9473](https://github.com/matrix-org/synapse/issues/9473))
32 - Purge chain cover indexes for events that were purged prior to Synapse v1.29.0. ([\#9542](https://github.com/matrix-org/synapse/issues/9542), [\#9583](https://github.com/matrix-org/synapse/issues/9583))
33 - Fix bug where federation requests were not correctly retried on 5xx responses. ([\#9567](https://github.com/matrix-org/synapse/issues/9567))
34 - Fix re-activating an account via the admin API when local passwords are disabled. ([\#9587](https://github.com/matrix-org/synapse/issues/9587))
35 - Fix a bug introduced in Synapse 1.20 which caused incoming federation transactions to stack up, causing slow recovery from outages. ([\#9597](https://github.com/matrix-org/synapse/issues/9597))
36 - Fix a bug introduced in v1.28.0 where the OpenID Connect callback endpoint could error with a `MacaroonInitException`. ([\#9620](https://github.com/matrix-org/synapse/issues/9620))
37 - Fix Internal Server Error on `GET /_synapse/client/saml2/authn_response` request. ([\#9623](https://github.com/matrix-org/synapse/issues/9623))
38
39
40 Updates to the Docker image
41 ---------------------------
42
43 - Make use of an improved malloc implementation (`jemalloc`) in the docker image. ([\#8553](https://github.com/matrix-org/synapse/issues/8553))
44
45
46 Improved Documentation
47 ----------------------
48
49 - Add relayd entry to reverse proxy example configurations. ([\#9508](https://github.com/matrix-org/synapse/issues/9508))
50 - Improve the SAML2 upgrade notes for 1.27.0. ([\#9550](https://github.com/matrix-org/synapse/issues/9550))
51 - Link to the "List user's media" admin API from the media admin API docs. ([\#9571](https://github.com/matrix-org/synapse/issues/9571))
52 - Clarify the spam checker modules documentation example to mention that `parse_config` is a required method. ([\#9580](https://github.com/matrix-org/synapse/issues/9580))
53 - Clarify the sample configuration for `stats` settings. ([\#9604](https://github.com/matrix-org/synapse/issues/9604))
54
55
56 Deprecations and Removals
57 -------------------------
58
59 - The `synapse_federation_last_sent_pdu_age` and `synapse_federation_last_received_pdu_age` prometheus metrics have been removed. They are replaced by `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time`. ([\#9540](https://github.com/matrix-org/synapse/issues/9540))
60 - Registering an Application Service user without using the `m.login.application_service` login type will be unsupported in an upcoming Synapse release. ([\#9559](https://github.com/matrix-org/synapse/issues/9559))
61
62
63 Internal Changes
64 ----------------
65
66 - Add tests to ResponseCache. ([\#9458](https://github.com/matrix-org/synapse/issues/9458))
67 - Add type hints to purge room and server notice admin API. ([\#9520](https://github.com/matrix-org/synapse/issues/9520))
68 - Add extra logging to ObservableDeferred when callbacks throw exceptions. ([\#9523](https://github.com/matrix-org/synapse/issues/9523))
69 - Fix incorrect type hints. ([\#9528](https://github.com/matrix-org/synapse/issues/9528), [\#9543](https://github.com/matrix-org/synapse/issues/9543), [\#9591](https://github.com/matrix-org/synapse/issues/9591), [\#9608](https://github.com/matrix-org/synapse/issues/9608), [\#9618](https://github.com/matrix-org/synapse/issues/9618))
70 - Add an additional test for purging a room. ([\#9541](https://github.com/matrix-org/synapse/issues/9541))
71 - Add a `.git-blame-ignore-revs` file with the hashes of auto-formatting. ([\#9560](https://github.com/matrix-org/synapse/issues/9560))
72 - Increase the threshold before which outbound federation to a server goes into "catch up" mode, which is expensive for the remote server to handle. ([\#9561](https://github.com/matrix-org/synapse/issues/9561))
73 - Fix spurious errors reported by the `config-lint.sh` script. ([\#9562](https://github.com/matrix-org/synapse/issues/9562))
74 - Fix type hints and tests for BlacklistingAgentWrapper and BlacklistingReactorWrapper. ([\#9563](https://github.com/matrix-org/synapse/issues/9563))
75 - Do not have mypy ignore type hints from unpaddedbase64. ([\#9568](https://github.com/matrix-org/synapse/issues/9568))
76 - Improve efficiency of calculating the auth chain in large rooms. ([\#9576](https://github.com/matrix-org/synapse/issues/9576))
77 - Convert `synapse.types.Requester` to an `attrs` class. ([\#9586](https://github.com/matrix-org/synapse/issues/9586))
78 - Add logging for redis connection setup. ([\#9590](https://github.com/matrix-org/synapse/issues/9590))
79 - Improve logging when processing incoming transactions. ([\#9596](https://github.com/matrix-org/synapse/issues/9596))
80 - Remove unused `stats.retention` setting, and emit a warning if stats are disabled. ([\#9604](https://github.com/matrix-org/synapse/issues/9604))
81 - Prevent attempting to bundle aggregations for state events in /context APIs. ([\#9619](https://github.com/matrix-org/synapse/issues/9619))
82
83
084 Synapse 1.29.0 (2021-03-08)
185 ===========================
286
1919 recursive-include scripts-dev *
2020 recursive-include synapse *.pyi
2121 recursive-include tests *.py
22 include tests/http/ca.crt
23 include tests/http/ca.key
24 include tests/http/server.key
22 recursive-include tests *.pem
23 recursive-include tests *.p8
24 recursive-include tests *.crt
25 recursive-include tests *.key
2526
2627 recursive-include synapse/res *
2728 recursive-include synapse/static *.css
182182 It is recommended to put a reverse proxy such as
183183 `nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_,
184184 `Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_,
185 `Caddy <https://caddyserver.com/docs/quick-starts/reverse-proxy>`_ or
186 `HAProxy <https://www.haproxy.org/>`_ in front of Synapse. One advantage of
185 `Caddy <https://caddyserver.com/docs/quick-starts/reverse-proxy>`_,
186 `HAProxy <https://www.haproxy.org/>`_ or
187 `relayd <https://man.openbsd.org/relayd.8>`_ in front of Synapse. One advantage of
187188 doing so is that it means that you can expose the default https port (443) to
188189 Matrix clients without needing to run Synapse with root privileges.
189190
122122 * If your server is configured for single sign-on via a SAML2 identity provider, you will
123123 need to add ``[synapse public baseurl]/_synapse/client/saml2/authn_response`` as a permitted
124124 "ACS location" (also known as "allowed callback URLs") at the identity provider.
125
126 The "Issuer" in the "AuthnRequest" to the SAML2 identity provider is also updated to
127 ``[synapse public baseurl]/_synapse/client/saml2/metadata.xml``. If your SAML2 identity
128 provider uses this property to validate or otherwise identify Synapse, its configuration
129 will need to be updated to use the new URL. Alternatively you could create a new, separate
130 "EntityDescriptor" in your SAML2 identity provider with the new URLs and leave the URLs in
131 the existing "EntityDescriptor" as they were.
125132
126133 Changes to HTML templates
127134 -------------------------
0 matrix-synapse (1.30.0-1) unstable; urgency=medium
1
2 * New upstream release.
3 * Update the watch URL.
4
5 -- Andrej Shadura <andrewsh@debian.org> Mon, 22 Mar 2021 18:36:56 +0100
6
07 matrix-synapse (1.29.0-1~fto10+1) buster-fasttrack; urgency=medium
18
29 * Rebuild for buster-fasttrack.
11 version=3
22
33 opts=filenamemangle=s/.+\/v?(\d\S*)\.tar\.gz/matrix-synapse-$1\.tar\.gz/,uversionmangle=s/-?rc/~rc/,repacksuffix=+dfsg,dversionmangle=s/\+dfsg$// \
4 https://github.com/matrix-org/synapse/tags .*/archive/v(\d[^\s\-]*)\.tar\.gz
4 https://github.com/matrix-org/synapse/tags .*/archive/(?:refs/tags/)?v(\d[^\s\-]*)\.tar\.gz
6868 libpq5 \
6969 libwebp6 \
7070 xmlsec1 \
71 libjemalloc2 \
7172 && rm -rf /var/lib/apt/lists/*
7273
7374 COPY --from=builder /install /usr/local
203203 timeout: 10s
204204 retries: 3
205205 ```
206
207 ## Using jemalloc
208
209 Jemalloc is embedded in the image and will be used instead of the default allocator.
210 You can read more about jemalloc in the Synapse [README](../README.md).
22 import codecs
33 import glob
44 import os
5 import platform
56 import subprocess
67 import sys
78
212213 if "-m" not in args:
213214 args = ["-m", synapse_worker] + args
214215
216 jemallocpath = "/usr/lib/%s-linux-gnu/libjemalloc.so.2" % (platform.machine(),)
217
218 if os.path.isfile(jemallocpath):
219 environ["LD_PRELOAD"] = jemallocpath
220 else:
221 log("Could not find %s, will not use" % (jemallocpath,))
222
215223 # if there are no config files passed to synapse, try adding the default file
216224 if not any(p.startswith("--config-path") or p.startswith("-c") for p in args):
217225 config_dir = environ.get("SYNAPSE_CONFIG_DIR", "/data")
247255 args = ["python"] + args
248256 if ownership is not None:
249257 args = ["gosu", ownership] + args
250 os.execv("/usr/sbin/gosu", args)
251 else:
252 os.execv("/usr/local/bin/python", args)
258 os.execve("/usr/sbin/gosu", args, environ)
259 else:
260 os.execve("/usr/local/bin/python", args, environ)
253261
254262
255263 if __name__ == "__main__":
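A note on the `os.execv` → `os.execve` change above: the jemalloc hook sets `LD_PRELOAD` on the `environ` mapping the script threads through its functions, and `execve` is the exec variant that hands that mapping to the new process image explicitly, rather than relying on it being `os.environ`. A toy sketch (paths and arguments hypothetical):

```python
# Toy sketch of why execve is used: pass the mutated mapping explicitly.
import os

environ = dict(os.environ)  # a local copy, like the `environ` start.py passes around
environ["LD_PRELOAD"] = "/usr/lib/x86_64-linux-gnu/libjemalloc.so.2"  # hypothetical path

args = ["python", "-m", "synapse.app.homeserver"]
# os.execv("/usr/local/bin/python", args)            # child would get the *current* env only
# os.execve("/usr/local/bin/python", args, environ)  # child gets the mutated mapping
```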
00 # Contents
1 - [List all media in a room](#list-all-media-in-a-room)
1 - [Querying media](#querying-media)
2 * [List all media in a room](#list-all-media-in-a-room)
3 * [List all media uploaded by a user](#list-all-media-uploaded-by-a-user)
24 - [Quarantine media](#quarantine-media)
35 * [Quarantining media by ID](#quarantining-media-by-id)
46 * [Quarantining media in a room](#quarantining-media-in-a-room)
911 * [Delete local media by date or size](#delete-local-media-by-date-or-size)
1012 - [Purge Remote Media API](#purge-remote-media-api)
1113
12 # List all media in a room
14 # Querying media
15
16 These APIs allow extracting media information from the homeserver.
17
18 ## List all media in a room
1319
1420 This API gets a list of known media in a room.
1521 However, it only shows media from unencrypted events or rooms.
3440 ]
3541 }
3642 ```
43
44 ## List all media uploaded by a user
45
46 All media uploaded by a local user can be listed via the
47 [List media of a user](user_admin_api.rst#list-media-of-a-user)
48 Admin API.
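Either listing can be driven from a script; a hedged sketch for the room listing above (the endpoint path follows the admin API conventions of this version, and the host and token are hypothetical — verify against your deployment):

```python
# Hedged sketch: list a room's media via the admin API documented above.
import requests
from urllib.parse import quote

BASE = "https://homeserver.example"  # hypothetical homeserver
ROOM_ID = "!roomid:example.com"      # hypothetical room

resp = requests.get(
    f"{BASE}/_synapse/admin/v1/room/{quote(ROOM_ID, safe='')}/media",
    headers={"Authorization": "Bearer <admin-access-token>"},
)
resp.raise_for_status()
body = resp.json()  # expected shape: {"local": [...], "remote": [...]}
print(len(body["local"]), "local /", len(body["remote"]), "remote media IDs")
```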
3749
3850 # Quarantine media
3951
225225 oidc_providers:
226226 - idp_id: github
227227 idp_name: Github
228 idp_brand: "org.matrix.github" # optional: styling hint for clients
228 idp_brand: "github" # optional: styling hint for clients
229229 discover: false
230230 issuer: "https://github.com/"
231231 client_id: "your-client-id" # TO BE FILLED
251251 oidc_providers:
252252 - idp_id: google
253253 idp_name: Google
254 idp_brand: "org.matrix.google" # optional: styling hint for clients
254 idp_brand: "google" # optional: styling hint for clients
255255 issuer: "https://accounts.google.com/"
256256 client_id: "your-client-id" # TO BE FILLED
257257 client_secret: "your-client-secret" # TO BE FILLED
298298 oidc_providers:
299299 - idp_id: gitlab
300300 idp_name: Gitlab
301 idp_brand: "org.matrix.gitlab" # optional: styling hint for clients
301 idp_brand: "gitlab" # optional: styling hint for clients
302302 issuer: "https://gitlab.com/"
303303 client_id: "your-client-id" # TO BE FILLED
304304 client_secret: "your-client-secret" # TO BE FILLED
333333 ```yaml
334334 - idp_id: facebook
335335 idp_name: Facebook
336 idp_brand: "org.matrix.facebook" # optional: styling hint for clients
336 idp_brand: "facebook" # optional: styling hint for clients
337337 discover: false
338338 issuer: "https://facebook.com"
339339 client_id: "your-client-id" # TO BE FILLED
385385 config:
386386 subject_claim: "id"
387387 localpart_template: "{{ user.login }}"
388 display_name_template: "{{ user.full_name }}"
388 display_name_template: "{{ user.full_name }}"
389389 ```
390390
391391 ### XWiki
400400 idp_name: "XWiki"
401401 issuer: "https://myxwikihost/xwiki/oidc/"
402402 client_id: "your-client-id" # TO BE FILLED
403 # Needed until https://github.com/matrix-org/synapse/issues/9212 is fixed
404 client_secret: "dontcare"
403 client_auth_method: none
405404 scopes: ["openid", "profile"]
406405 user_profile_method: "userinfo_endpoint"
407406 user_mapping_provider:
409408 localpart_template: "{{ user.preferred_username }}"
410409 display_name_template: "{{ user.name }}"
411410 ```
411
412 ## Apple
413
414 Configuring "Sign in with Apple" (SiWA) requires an Apple Developer account.
415
416 You will need to create a new "Services ID" for SiWA, and create and download a
417 private key with "SiWA" enabled.
418
419 As well as the private key file, you will need:
420 * Client ID: the "identifier" you gave the "Services ID".
421 * Team ID: a 10-character ID associated with your developer account.
422 * Key ID: the 10-character identifier for the key.
423
424 https://help.apple.com/developer-account/?lang=en#/dev77c875b7e has more
425 documentation on setting up SiWA.
426
427 The synapse config will look like this:
428
429 ```yaml
430 - idp_id: apple
431 idp_name: Apple
432 issuer: "https://appleid.apple.com"
433 client_id: "your-client-id" # Set to the "identifier" for your "Services ID"
434 client_auth_method: "client_secret_post"
435 client_secret_jwt_key:
436 key_file: "/path/to/AuthKey_KEYIDCODE.p8" # point to your key file
437 jwt_header:
438 alg: ES256
439 kid: "KEYIDCODE" # Set to the 10-char Key ID
440 jwt_payload:
441 iss: TEAMIDCODE # Set to the 10-char Team ID
442 scopes: ["name", "email", "openid"]
443 authorization_endpoint: https://appleid.apple.com/auth/authorize?response_mode=form_post
444 user_mapping_provider:
445 config:
446 email_template: "{{ user.email }}"
447 ```
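For context on what Synapse does with `client_secret_jwt_key`: at token-exchange time it mints a short-lived JWT signed with the key and sends it as the client secret. A minimal sketch, assuming authlib (which Synapse uses for OIDC); the helper name and the payload fields beyond `iss` are illustrative, not Synapse's actual code:

```python
# Minimal sketch, assuming authlib; not Synapse's actual implementation.
import time
from authlib.jose import jwt

def make_client_secret(key_pem: str, alg: str, kid: str, iss: str) -> str:
    header = {"alg": alg, "kid": kid}  # mirrors jwt_header above
    now = int(time.time())
    # jwt_payload supplies iss; the iat/exp timestamps here are assumptions
    payload = {"iss": iss, "iat": now, "exp": now + 3600}
    return jwt.encode(header, payload, key_pem).decode("ascii")

with open("/path/to/AuthKey_KEYIDCODE.p8") as f:
    secret = make_client_secret(f.read(), "ES256", "KEYIDCODE", "TEAMIDCODE")
```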
22 It is recommended to put a reverse proxy such as
33 [nginx](https://nginx.org/en/docs/http/ngx_http_proxy_module.html),
44 [Apache](https://httpd.apache.org/docs/current/mod/mod_proxy_http.html),
5 [Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy) or
6 [HAProxy](https://www.haproxy.org/) in front of Synapse. One advantage
5 [Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy),
6 [HAProxy](https://www.haproxy.org/) or
7 [relayd](https://man.openbsd.org/relayd.8) in front of Synapse. One advantage
78 of doing so is that it means that you can expose the default https port
89 (443) to Matrix clients without needing to run Synapse with root
910 privileges.
161162 server matrix 127.0.0.1:8008
162163 ```
163164
165 ### Relayd
166
167 ```
168 table <webserver> { 127.0.0.1 }
169 table <matrixserver> { 127.0.0.1 }
170
171 http protocol "https" {
172 tls { no tlsv1.0, ciphers "HIGH" }
173 tls keypair "example.com"
174 match header set "X-Forwarded-For" value "$REMOTE_ADDR"
175 match header set "X-Forwarded-Proto" value "https"
176
177 # set CORS header for .well-known/matrix/server, .well-known/matrix/client
178 # httpd does not support setting headers, so do it here
179 match request path "/.well-known/matrix/*" tag "matrix-cors"
180 match response tagged "matrix-cors" header set "Access-Control-Allow-Origin" value "*"
181
182 pass quick path "/_matrix/*" forward to <matrixserver>
183 pass quick path "/_synapse/client/*" forward to <matrixserver>
184
185 # pass on non-matrix traffic to webserver
186 pass forward to <webserver>
187 }
188
189 relay "https_traffic" {
190 listen on egress port 443 tls
191 protocol "https"
192 forward to <matrixserver> port 8008 check tcp
193 forward to <webserver> port 8080 check tcp
194 }
195
196 http protocol "matrix" {
197 tls { no tlsv1.0, ciphers "HIGH" }
198 tls keypair "example.com"
199 block
200 pass quick path "/_matrix/*" forward to <matrixserver>
201 pass quick path "/_synapse/client/*" forward to <matrixserver>
202 }
203
204 relay "matrix_federation" {
205 listen on egress port 8448 tls
206 protocol "matrix"
207 forward to <matrixserver> port 8008 check tcp
208 }
209 ```
210
164211 ## Homeserver Configuration
165212
166213 You will also want to set `bind_addresses: ['127.0.0.1']` and
8888 # Whether to require authentication to retrieve profile data (avatars,
8989 # display names) of other users through the client API. Defaults to
9090 # 'false'. Note that profile data is also available via the federation
91 # API, so this setting is of limited value if federation is enabled on
92 # the server.
91 # API, unless allow_profile_lookup_over_federation is set to false.
9392 #
9493 #require_auth_for_profile_requests: true
9594
17791778 #
17801779 # client_id: Required. oauth2 client id to use.
17811780 #
1782 # client_secret: Required. oauth2 client secret to use.
1781 # client_secret: oauth2 client secret to use. May be omitted if
1782 # client_secret_jwt_key is given, or if client_auth_method is 'none'.
1783 #
1784 # client_secret_jwt_key: Alternative to client_secret: details of a key used
1785 # to create a JSON Web Token to be used as an OAuth2 client secret. If
1786 # given, must be a dictionary with the following properties:
1787 #
1788 # key: a pem-encoded signing key. Must be a suitable key for the
1789 # algorithm specified. Required unless 'key_file' is given.
1790 #
1791 # key_file: the path to file containing a pem-encoded signing key file.
1792 # Required unless 'key' is given.
1793 #
1794 # jwt_header: a dictionary giving properties to include in the JWT
1795 # header. Must include the key 'alg', giving the algorithm used to
1796 # sign the JWT, such as "ES256", using the JWA identifiers in
1797 # RFC7518.
1798 #
1799 # jwt_payload: an optional dictionary giving properties to include in
1800 # the JWT payload. Normally this should include an 'iss' key.
17831801 #
17841802 # client_auth_method: auth method to use when exchanging the token. Valid
17851803 # values are 'client_secret_basic' (default), 'client_secret_post' and
19001918 #
19011919 #- idp_id: github
19021920 # idp_name: Github
1903 # idp_brand: org.matrix.github
1921 # idp_brand: github
19041922 # discover: false
19051923 # issuer: "https://github.com/"
19061924 # client_id: "your-client-id" # TO BE FILLED
26262644
26272645
26282646
2629 # Local statistics collection. Used in populating the room directory.
2630 #
2631 # 'bucket_size' controls how large each statistics timeslice is. It can
2632 # be defined in a human readable short form -- e.g. "1d", "1y".
2633 #
2634 # 'retention' controls how long historical statistics will be kept for.
2635 # It can be defined in a human readable short form -- e.g. "1d", "1y".
2636 #
2637 #
2638 #stats:
2639 # enabled: true
2640 # bucket_size: 1d
2641 # retention: 1y
2647 # Settings for local room and user statistics collection. See
2648 # docs/room_and_user_statistics.md.
2649 #
2650 stats:
2651 # Uncomment the following to disable room and user statistics. Note that doing
2652 # so may cause certain features (such as the room directory) not to work
2653 # correctly.
2654 #
2655 #enabled: false
2656
2657 # The size of each timeslice in the room_stats_historical and
2658 # user_stats_historical tables, as a time period. Defaults to "1d".
2659 #
2660 #bucket_size: 1h
26422661
26432662
26442663 # Server Notices room configuration
1313 * An instance of `synapse.module_api.ModuleApi`.
1414
1515 It then implements methods which return a boolean to alter behavior in Synapse.
16 All the methods must be defined.
1617
1718 There's a generic method for checking every event (`check_event_for_spam`), as
1819 well as some specific methods:
2324 * `user_may_publish_room`
2425 * `check_username_for_spam`
2526 * `check_registration_for_spam`
27 * `check_media_file_for_spam`
2628
2729 The details of each of these methods (as well as their inputs and outputs)
2830 are documented in the `synapse.events.spamcheck.SpamChecker` class.
2931
3032 The `ModuleApi` class provides a way for the custom spam checker class to
3133 call back into the homeserver internals.
34
35 Additionally, a `parse_config` method is mandatory and receives the plugin config
36 dictionary. After parsing, it must return an object which will be
37 passed to `__init__` later.
3238
3339 ### Example
3440
4046 self.config = config
4147 self.api = api
4248
49 @staticmethod
50 def parse_config(config):
51 return config
52
4353 async def check_event_for_spam(self, foo):
4454 return False # allow all events
4555
5868 async def check_username_for_spam(self, user_profile):
5969 return False # allow all usernames
6070
61 async def check_registration_for_spam(self, email_threepid, username, request_info):
71 async def check_registration_for_spam(
72 self,
73 email_threepid,
74 username,
75 request_info,
76 auth_provider_id,
77 ):
6278 return RegistrationBehaviour.ALLOW # allow all registrations
6379
6480 async def check_media_file_for_spam(self, file_wrapper, file_info):
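Assembled from the fragments above, a complete module implementing the documented interface might look like the following sketch (the `RegistrationBehaviour` import path and the `blocked_idps` config key are assumptions; `synapse.events.spamcheck.SpamChecker` remains the authoritative reference):

```python
# Sketch of a full spam checker module per the docs above; hedged, not canonical.
from synapse.spam_checker_api import RegistrationBehaviour  # assumed import path

class ExampleSpamChecker:
    def __init__(self, config, api):
        self.config = config  # the object returned by parse_config
        self.api = api        # synapse.module_api.ModuleApi instance

    @staticmethod
    def parse_config(config):
        # Mandatory: validate/transform the raw dict; the returned value is
        # passed to __init__ as `config`.
        return config

    async def check_event_for_spam(self, event):
        return False  # allow all events

    async def check_registration_for_spam(
        self, email_threepid, username, request_info, auth_provider_id
    ):
        # Illustrative use of the new argument: deny sign-ups via a
        # hypothetical blocked IdP, allow everything else.
        if auth_provider_id in self.config.get("blocked_idps", []):
            return RegistrationBehaviour.DENY
        return RegistrationBehaviour.ALLOW
```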
6868 synapse/util/async_helpers.py,
6969 synapse/util/caches,
7070 synapse/util/metrics.py,
71 synapse/util/macaroons.py,
7172 synapse/util/stringutils.py,
7273 tests/replication,
7374 tests/test_utils,
113114 ignore_missing_imports = True
114115
115116 [mypy-saml2.*]
116 ignore_missing_imports = True
117
118 [mypy-unpaddedbase64]
119117 ignore_missing_imports = True
120118
121119 [mypy-canonicaljson]
11 # Find linting errors in Synapse's default config file.
22 # Exits with 0 if there are no problems, or another code otherwise.
33
4 # cd to the root of the repository
5 cd `dirname $0`/..
6
7 # Restore backup of sample config upon script exit
8 trap "mv docs/sample_config.yaml.bak docs/sample_config.yaml" EXIT
9
410 # Fix non-lowercase true/false values
511 sed -i.bak -E "s/: +True/: true/g; s/: +False/: false/g;" docs/sample_config.yaml
6 rm docs/sample_config.yaml.bak
712
813 # Check if anything changed
9 git diff --exit-code docs/sample_config.yaml
14 diff docs/sample_config.yaml docs/sample_config.yaml.bak
22
33 [check-manifest]
44 ignore =
5 .git-blame-ignore-revs
56 contrib
67 contrib/*
78 docs/*
1616 """
1717 from typing import Any, List, Optional, Type, Union
1818
19 class RedisProtocol:
19 from twisted.internet import protocol
20
21 class RedisProtocol(protocol.Protocol):
2022 def publish(self, channel: str, message: bytes): ...
2123 async def ping(self) -> None: ...
2224 async def set(
5153
5254 class ConnectionHandler: ...
5355
54 class RedisFactory:
56 class RedisFactory(protocol.ReconnectingClientFactory):
5557 continueTrying: bool
5658 handler: RedisProtocol
5759 pool: List[RedisProtocol]
4747 except ImportError:
4848 pass
4949
50 __version__ = "1.29.0"
50 __version__ = "1.30.0"
5151
5252 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
5353 # We import here so that we don't have to install a bunch of deps when
3838 from synapse.storage.databases.main.registration import TokenLookupResult
3939 from synapse.types import StateMap, UserID
4040 from synapse.util.caches.lrucache import LruCache
41 from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
4142 from synapse.util.metrics import Measure
4243
4344 logger = logging.getLogger(__name__)
162163
163164 async def get_user_by_req(
164165 self,
165 request: Request,
166 request: SynapseRequest,
166167 allow_guest: bool = False,
167168 rights: str = "access",
168169 allow_expired: bool = False,
407408 raise _InvalidMacaroonException()
408409
409410 try:
410 user_id = self.get_user_id_from_macaroon(macaroon)
411 user_id = get_value_from_macaroon(macaroon, "user_id")
411412
412413 guest = False
413414 for caveat in macaroon.caveats:
415416 guest = True
416417
417418 self.validate_macaroon(macaroon, rights, user_id=user_id)
418 except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
419 except (
420 pymacaroons.exceptions.MacaroonException,
421 KeyError,
422 TypeError,
423 ValueError,
424 ):
419425 raise InvalidClientTokenError("Invalid macaroon passed.")
420426
421427 if rights == "access":
422428 self.token_cache[token] = (user_id, guest)
423429
424430 return user_id, guest
425
426 def get_user_id_from_macaroon(self, macaroon):
427 """Retrieve the user_id given by the caveats on the macaroon.
428
429 Does *not* validate the macaroon.
430
431 Args:
432 macaroon (pymacaroons.Macaroon): The macaroon to validate
433
434 Returns:
435 (str) user id
436
437 Raises:
438 InvalidClientCredentialsError if there is no user_id caveat in the
439 macaroon
440 """
441 user_prefix = "user_id = "
442 for caveat in macaroon.caveats:
443 if caveat.caveat_id.startswith(user_prefix):
444 return caveat.caveat_id[len(user_prefix) :]
445 raise InvalidClientTokenError("No user caveat in macaroon")
446431
447432 def validate_macaroon(self, macaroon, type_string, user_id):
448433 """
464449 v.satisfy_exact("type = " + type_string)
465450 v.satisfy_exact("user_id = %s" % user_id)
466451 v.satisfy_exact("guest = true")
467 v.satisfy_general(self._verify_expiry)
452 satisfy_expiry(v, self.clock.time_msec)
468453
469454 # access_tokens include a nonce for uniqueness: any value is acceptable
470455 v.satisfy_general(lambda c: c.startswith("nonce = "))
471456
472457 v.verify(macaroon, self._macaroon_secret_key)
473
474 def _verify_expiry(self, caveat):
475 prefix = "time < "
476 if not caveat.startswith(prefix):
477 return False
478 expiry = int(caveat[len(prefix) :])
479 now = self.hs.get_clock().time_msec()
480 return now < expiry
481458
482459 def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
483460 token = self.get_access_token_from_request(request)
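The two helpers imported above from `synapse/util/macaroons.py` are not shown in this diff; reconstructed from the inline code they replace, they plausibly look like this (details such as exact exception types may differ in the real module):

```python
# Hedged reconstruction of the new synapse/util/macaroons.py helpers,
# based on the inline code removed above.
from typing import Callable

import pymacaroons

def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
    """Return the value of the first `<key> = <value>` caveat.

    Does *not* validate the macaroon. Raises KeyError if no such caveat
    exists (note that KeyError was added to the except clause above).
    """
    prefix = key + " = "
    for caveat in macaroon.caveats:
        if caveat.caveat_id.startswith(prefix):
            return caveat.caveat_id[len(prefix) :]
    raise KeyError(key)

def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None:
    """Make the verifier accept `time < <ms>` caveats that have not expired."""
    def verify_expiry_caveat(caveat: str) -> bool:
        prefix = "time < "
        if not caveat.startswith(prefix):
            return False
        return get_time_ms() < int(caveat[len(prefix) :])
    v.satisfy_general(verify_expiry_caveat)
```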
8989 self.clock = hs.get_clock()
9090
9191 self.protocol_meta_cache = ResponseCache(
92 hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
92 hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
9393 ) # type: ResponseCache[Tuple[str, str]]
9494
9595 async def query_user(self, service, user_id):
211211
212212 @classmethod
213213 def read_file(cls, file_path, config_name):
214 cls.check_file(file_path, config_name)
215 with open(file_path) as file_stream:
216 return file_stream.read()
214 """Deprecated: call read_file directly"""
215 return read_file(file_path, (config_name,))
217216
218217 def read_template(self, filename: str) -> jinja2.Template:
219218 """Load a template file from disk.
893892 return self._get_instance(key)
894893
895894
896 __all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
895 def read_file(file_path: Any, config_path: Iterable[str]) -> str:
896 """Check the given file exists, and read it into a string
897
898 If it does not, emit an error indicating the problem
899
900 Args:
901 file_path: the file to be read
902 config_path: where in the configuration file_path came from, so that a useful
903 error can be emitted if it does not exist.
904 Returns:
905 content of the file.
906 Raises:
907 ConfigError if there is a problem reading the file.
908 """
909 if not isinstance(file_path, str):
910 raise ConfigError("%r is not a string" % (file_path,), config_path)
911
912 try:
913 os.stat(file_path)
914 with open(file_path) as file_stream:
915 return file_stream.read()
916 except OSError as e:
917 raise ConfigError("Error accessing file %r" % (file_path,), config_path) from e
918
919
920 __all__ = [
921 "Config",
922 "RootConfig",
923 "ShardedWorkerHandlingConfig",
924 "RoutableShardedWorkerHandlingConfig",
925 "read_file",
926 ]
151151
152152 class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
153153 def get_instance(self, key: str) -> str: ...
154
155 def read_file(file_path: Any, config_path: Iterable[str]) -> str: ...
2020 from string import Template
2121
2222 import yaml
23 from zope.interface import implementer
2324
2425 from twisted.logger import (
26 ILogObserver,
2527 LogBeginner,
2628 STDLibLogObserver,
2729 eventAsText,
226228
227229 threadlocal = threading.local()
228230
229 def _log(event):
231 @implementer(ILogObserver)
232 def _log(event: dict) -> None:
230233 if "log_text" in event:
231234 if event["log_text"].startswith("DNSDatagramProtocol starting on "):
232235 return
1414 # limitations under the License.
1515
1616 from collections import Counter
17 from typing import Iterable, Optional, Tuple, Type
17 from typing import Iterable, Mapping, Optional, Tuple, Type
1818
1919 import attr
2020
2424 from synapse.util.module_loader import load_module
2525 from synapse.util.stringutils import parse_and_validate_mxc_uri
2626
27 from ._base import Config, ConfigError
27 from ._base import Config, ConfigError, read_file
2828
2929 DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider"
3030
9696 #
9797 # client_id: Required. oauth2 client id to use.
9898 #
99 # client_secret: Required. oauth2 client secret to use.
99 # client_secret: oauth2 client secret to use. May be omitted if
100 # client_secret_jwt_key is given, or if client_auth_method is 'none'.
101 #
102 # client_secret_jwt_key: Alternative to client_secret: details of a key used
103 # to create a JSON Web Token to be used as an OAuth2 client secret. If
104 # given, must be a dictionary with the following properties:
105 #
106 # key: a pem-encoded signing key. Must be a suitable key for the
107 # algorithm specified. Required unless 'key_file' is given.
108 #
109 # key_file: the path to file containing a pem-encoded signing key file.
110 # Required unless 'key' is given.
111 #
112 # jwt_header: a dictionary giving properties to include in the JWT
113 # header. Must include the key 'alg', giving the algorithm used to
114 # sign the JWT, such as "ES256", using the JWA identifiers in
115 # RFC7518.
116 #
117 # jwt_payload: an optional dictionary giving properties to include in
118 # the JWT payload. Normally this should include an 'iss' key.
100119 #
101120 # client_auth_method: auth method to use when exchanging the token. Valid
102121 # values are 'client_secret_basic' (default), 'client_secret_post' and
217236 #
218237 #- idp_id: github
219238 # idp_name: Github
220 # idp_brand: org.matrix.github
239 # idp_brand: github
221240 # discover: false
222241 # issuer: "https://github.com/"
223242 # client_id: "your-client-id" # TO BE FILLED
239258 # jsonschema definition of the configuration settings for an oidc identity provider
240259 OIDC_PROVIDER_CONFIG_SCHEMA = {
241260 "type": "object",
242 "required": ["issuer", "client_id", "client_secret"],
261 "required": ["issuer", "client_id"],
243262 "properties": {
244263 "idp_id": {
245264 "type": "string",
252271 "idp_icon": {"type": "string"},
253272 "idp_brand": {
254273 "type": "string",
255 # MSC2758-style namespaced identifier
274 "minLength": 1,
275 "maxLength": 255,
276 "pattern": "^[a-z][a-z0-9_.-]*$",
277 },
278 "idp_unstable_brand": {
279 "type": "string",
256280 "minLength": 1,
257281 "maxLength": 255,
258282 "pattern": "^[a-z][a-z0-9_.-]*$",
261285 "issuer": {"type": "string"},
262286 "client_id": {"type": "string"},
263287 "client_secret": {"type": "string"},
288 "client_secret_jwt_key": {
289 "type": "object",
290 "required": ["jwt_header"],
291 "oneOf": [
292 {"required": ["key"]},
293 {"required": ["key_file"]},
294 ],
295 "properties": {
296 "key": {"type": "string"},
297 "key_file": {"type": "string"},
298 "jwt_header": {
299 "type": "object",
300 "required": ["alg"],
301 "properties": {
302 "alg": {"type": "string"},
303 },
304 "additionalProperties": {"type": "string"},
305 },
306 "jwt_payload": {
307 "type": "object",
308 "additionalProperties": {"type": "string"},
309 },
310 },
311 },
264312 "client_auth_method": {
265313 "type": "string",
266314 # the following list is the same as the keys of
403451 "idp_icon must be a valid MXC URI", config_path + ("idp_icon",)
404452 ) from e
405453
454 client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key")
455 client_secret_jwt_key = None # type: Optional[OidcProviderClientSecretJwtKey]
456 if client_secret_jwt_key_config is not None:
457 keyfile = client_secret_jwt_key_config.get("key_file")
458 if keyfile:
459 key = read_file(keyfile, config_path + ("client_secret_jwt_key",))
460 else:
461 key = client_secret_jwt_key_config["key"]
462 client_secret_jwt_key = OidcProviderClientSecretJwtKey(
463 key=key,
464 jwt_header=client_secret_jwt_key_config["jwt_header"],
465 jwt_payload=client_secret_jwt_key_config.get("jwt_payload", {}),
466 )
467
406468 return OidcProviderConfig(
407469 idp_id=idp_id,
408470 idp_name=oidc_config.get("idp_name", "OIDC"),
409471 idp_icon=idp_icon,
410472 idp_brand=oidc_config.get("idp_brand"),
473 unstable_idp_brand=oidc_config.get("idp_unstable_brand"),
411474 discover=oidc_config.get("discover", True),
412475 issuer=oidc_config["issuer"],
413476 client_id=oidc_config["client_id"],
414 client_secret=oidc_config["client_secret"],
477 client_secret=oidc_config.get("client_secret"),
478 client_secret_jwt_key=client_secret_jwt_key,
415479 client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
416480 scopes=oidc_config.get("scopes", ["openid"]),
417481 authorization_endpoint=oidc_config.get("authorization_endpoint"),
427491
428492
429493 @attr.s(slots=True, frozen=True)
494 class OidcProviderClientSecretJwtKey:
495 # a pem-encoded signing key
496 key = attr.ib(type=str)
497
498 # properties to include in the JWT header
499 jwt_header = attr.ib(type=Mapping[str, str])
500
501 # properties to include in the JWT payload.
502 jwt_payload = attr.ib(type=Mapping[str, str])
503
504
505 @attr.s(slots=True, frozen=True)
430506 class OidcProviderConfig:
431507 # a unique identifier for this identity provider. Used in the 'user_external_ids'
432508 # table, as well as the query/path parameter used in the login protocol.
441517 # Optional brand identifier for this IdP.
442518 idp_brand = attr.ib(type=Optional[str])
443519
520 # Optional brand identifier for the unstable API (see MSC2858).
521 unstable_idp_brand = attr.ib(type=Optional[str])
522
444523 # whether the OIDC discovery mechanism is used to discover endpoints
445524 discover = attr.ib(type=bool)
446525
451530 # oauth2 client id to use
452531 client_id = attr.ib(type=str)
453532
454 # oauth2 client secret to use
455 client_secret = attr.ib(type=str)
533 # oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate
534 # a secret.
535 client_secret = attr.ib(type=Optional[str])
536
537 # key to use to construct a JWT to use as a client secret. May be `None` if
538 # `client_secret` is set.
539 client_secret_jwt_key = attr.ib(type=Optional[OidcProviderClientSecretJwtKey])
456540
457541 # auth method to use when exchanging the token.
458542 # Valid values are 'client_secret_basic', 'client_secret_post' and
840840 # Whether to require authentication to retrieve profile data (avatars,
841841 # display names) of other users through the client API. Defaults to
842842 # 'false'. Note that profile data is also available via the federation
843 # API, so this setting is of limited value if federation is enabled on
844 # the server.
843 # API, unless allow_profile_lookup_over_federation is set to false.
845844 #
846845 #require_auth_for_profile_requests: true
847846
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 import sys
15 import logging
1616
1717 from ._base import Config
18
19 ROOM_STATS_DISABLED_WARN = """\
20 WARNING: room/user statistics have been disabled via the stats.enabled
21 configuration setting. This means that certain features (such as the room
22 directory) will not operate correctly. Future versions of Synapse may ignore
23 this setting.
24
25 To fix this warning, remove the stats.enabled setting from your configuration
26 file.
27 --------------------------------------------------------------------------------"""
28
29 logger = logging.getLogger(__name__)
1830
1931
2032 class StatsConfig(Config):
2739 def read_config(self, config, **kwargs):
2840 self.stats_enabled = True
2941 self.stats_bucket_size = 86400 * 1000
30 self.stats_retention = sys.maxsize
3142 stats_config = config.get("stats", None)
3243 if stats_config:
3344 self.stats_enabled = stats_config.get("enabled", self.stats_enabled)
3445 self.stats_bucket_size = self.parse_duration(
3546 stats_config.get("bucket_size", "1d")
3647 )
37 self.stats_retention = self.parse_duration(
38 stats_config.get("retention", "%ds" % (sys.maxsize,))
39 )
48 if not self.stats_enabled:
49 logger.warning(ROOM_STATS_DISABLED_WARN)
4050
4151 def generate_config_section(self, config_dir_path, server_name, **kwargs):
4252 return """
43 # Local statistics collection. Used in populating the room directory.
53 # Settings for local room and user statistics collection. See
54 # docs/room_and_user_statistics.md.
4455 #
45 # 'bucket_size' controls how large each statistics timeslice is. It can
46 # be defined in a human readable short form -- e.g. "1d", "1y".
47 #
48 # 'retention' controls how long historical statistics will be kept for.
49 # It can be defined in a human readable short form -- e.g. "1d", "1y".
50 #
51 #
52 #stats:
53 # enabled: true
54 # bucket_size: 1d
55 # retention: 1y
56 stats:
57 # Uncomment the following to disable room and user statistics. Note that doing
58 # so may cause certain features (such as the room directory) not to work
59 # correctly.
60 #
61 #enabled: false
62
63 # The size of each timeslice in the room_stats_historical and
64 # user_stats_historical tables, as a time period. Defaults to "1d".
65 #
66 #bucket_size: 1h
5667 """
1414 # limitations under the License.
1515
1616 import inspect
17 import logging
1718 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
1819
1920 from synapse.rest.media.v1._base import FileInfo
2526 if TYPE_CHECKING:
2627 import synapse.events
2728 import synapse.server
29
30 logger = logging.getLogger(__name__)
2831
2932
3033 class SpamChecker:
189192 email_threepid: Optional[dict],
190193 username: Optional[str],
191194 request_info: Collection[Tuple[str, str]],
195 auth_provider_id: Optional[str] = None,
192196 ) -> RegistrationBehaviour:
193197 """Checks if we should allow the given registration request.
194198
197201 username: The request user name, if any
198202 request_info: List of tuples of user agent and IP that
199203 were used during the registration process.
204 auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml",
205 "cas". If any. Note this does not include users registered
206 via a password provider.
200207
201208 Returns:
202209 Enum for how the request should be handled
207214 # spam checker
208215 checker = getattr(spam_checker, "check_registration_for_spam", None)
209216 if checker:
210 behaviour = await maybe_awaitable(
211 checker(email_threepid, username, request_info)
212 )
217 # Provide auth_provider_id if the function supports it
218 checker_args = inspect.signature(checker)
219 if len(checker_args.parameters) == 4:
220 d = checker(
221 email_threepid,
222 username,
223 request_info,
224 auth_provider_id,
225 )
226 elif len(checker_args.parameters) == 3:
227 d = checker(email_threepid, username, request_info)
228 else:
229 logger.error(
230 "Invalid signature for %s.check_registration_for_spam. Denying registration",
231 spam_checker.__module__,
232 )
233 return RegistrationBehaviour.DENY
234
235 behaviour = await maybe_awaitable(d)
213236 assert isinstance(behaviour, RegistrationBehaviour)
214237 if behaviour != RegistrationBehaviour.ALLOW:
215238 return behaviour
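The `inspect.signature` dance above is a general pattern for growing a plugin callback's signature without breaking existing modules; in isolation (toy names and return values):

```python
# Standalone toy of signature-based backwards compatibility.
import inspect

def call_compat(callback, *args):
    """Pass only as many positional arguments as the callback declares."""
    n = len(inspect.signature(callback).parameters)
    if n > len(args):
        raise TypeError("callback wants more arguments than we can supply")
    return callback(*args[:n])

def old_checker(email_threepid, username, request_info):
    return "ALLOW"  # written before auth_provider_id existed

def new_checker(email_threepid, username, request_info, auth_provider_id):
    return "DENY" if auth_provider_id == "oidc" else "ALLOW"

print(call_compat(old_checker, None, "alice", [], "oidc"))  # ALLOW
print(call_compat(new_checker, None, "alice", [], "oidc"))  # DENY
```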
2121 Awaitable,
2222 Callable,
2323 Dict,
24 Iterable,
2425 List,
2526 Optional,
2627 Tuple,
8990 "Time taken to process an event",
9091 )
9192
92
93 last_pdu_age_metric = Gauge(
94 "synapse_federation_last_received_pdu_age",
95 "The age (in seconds) of the last PDU successfully received from the given domain",
93 last_pdu_ts_metric = Gauge(
94 "synapse_federation_last_received_pdu_time",
95 "The timestamp of the last PDU which was successfully received from the given domain",
9696 labelnames=("server_name",),
9797 )
9898
9999
100100 class FederationServer(FederationBase):
101 def __init__(self, hs):
101 def __init__(self, hs: "HomeServer"):
102102 super().__init__(hs)
103103
104104 self.auth = hs.get_auth()
111111 # with FederationHandlerRegistry.
112112 hs.get_directory_handler()
113113
114 self._federation_ratelimiter = hs.get_federation_ratelimiter()
115
116114 self._server_linearizer = Linearizer("fed_server")
117 self._transaction_linearizer = Linearizer("fed_txn_handler")
115
116 # origins that we are currently processing a transaction from.
117 # a dict from origin to txn id.
118 self._active_transactions = {} # type: Dict[str, str]
118119
119120 # We cache results for transaction with the same ID
120121 self._transaction_resp_cache = ResponseCache(
121 hs, "fed_txn_handler", timeout_ms=30000
122 hs.get_clock(), "fed_txn_handler", timeout_ms=30000
122123 ) # type: ResponseCache[Tuple[str, str]]
123124
124125 self.transaction_actions = TransactionActions(self.store)
128129 # We cache responses to state queries, as they take a while and often
129130 # come in waves.
130131 self._state_resp_cache = ResponseCache(
131 hs, "state_resp", timeout_ms=30000
132 hs.get_clock(), "state_resp", timeout_ms=30000
132133 ) # type: ResponseCache[Tuple[str, str]]
133134 self._state_ids_resp_cache = ResponseCache(
134 hs, "state_ids_resp", timeout_ms=30000
135 hs.get_clock(), "state_ids_resp", timeout_ms=30000
135136 ) # type: ResponseCache[Tuple[str, str]]
136137
137138 self._federation_metrics_domains = (
167168 raise Exception("Transaction missing transaction_id")
168169
169170 logger.debug("[%s] Got transaction", transaction_id)
171
172 # Reject malformed transactions early: reject if too many PDUs/EDUs
173 if len(transaction.pdus) > 50 or ( # type: ignore
174 hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
175 ):
176 logger.info("Transaction PDU or EDU count too large. Returning 400")
177 return 400, {}
178
179 # we only process one transaction from each origin at a time. We need to do
180 # this check here, rather than in _on_incoming_transaction_inner so that we
181 # don't cache the rejection in _transaction_resp_cache (so that if the txn
182 # arrives again later, we can process it).
183 current_transaction = self._active_transactions.get(origin)
184 if current_transaction and current_transaction != transaction_id:
185 logger.warning(
186 "Received another txn %s from %s while still processing %s",
187 transaction_id,
188 origin,
189 current_transaction,
190 )
191 return 429, {
192 "errcode": Codes.UNKNOWN,
193 "error": "Too many concurrent transactions",
194 }
195
196 # CRITICAL SECTION: we must now not await until we populate _active_transactions
197 # in _on_incoming_transaction_inner.
170198
171199 # We wrap in a ResponseCache so that we de-duplicate retried
172200 # transactions.
181209 async def _on_incoming_transaction_inner(
182210 self, origin: str, transaction: Transaction, request_time: int
183211 ) -> Tuple[int, Dict[str, Any]]:
184 # Use a linearizer to ensure that transactions from a remote are
185 # processed in order.
186 with await self._transaction_linearizer.queue(origin):
187 # We rate limit here *after* we've queued up the incoming requests,
188 # so that we don't fill up the ratelimiter with blocked requests.
189 #
190 # This is important as the ratelimiter allows N concurrent requests
191 # at a time, and only starts ratelimiting if there are more requests
192 # than that being processed at a time. If we queued up requests in
193 # the linearizer/response cache *after* the ratelimiting then those
194 # queued up requests would count as part of the allowed limit of N
195 # concurrent requests.
196 with self._federation_ratelimiter.ratelimit(origin) as d:
197 await d
198
199 result = await self._handle_incoming_transaction(
200 origin, transaction, request_time
201 )
202
203 return result
212 # CRITICAL SECTION: the first thing we must do (before awaiting) is
213 # add an entry to _active_transactions.
214 assert origin not in self._active_transactions
215 self._active_transactions[origin] = transaction.transaction_id # type: ignore
216
217 try:
218 result = await self._handle_incoming_transaction(
219 origin, transaction, request_time
220 )
221 return result
222 finally:
223 del self._active_transactions[origin]
204224
205225 async def _handle_incoming_transaction(
206226 self, origin: str, transaction: Transaction, request_time: int
225245 return response
226246
227247 logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore
228
229 # Reject if PDU count > 50 or EDU count > 100
230 if len(transaction.pdus) > 50 or ( # type: ignore
231 hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
232 ):
233
234 logger.info("Transaction PDU or EDU count too large. Returning 400")
235
236 response = {}
237 await self.transaction_actions.set_response(
238 origin, transaction, 400, response
239 )
240 return 400, response
241248
242249 # We process PDUs and EDUs in parallel. This is important as we don't
243250 # want to block things like to device messages from reaching clients
334341 # impose a limit to avoid going too crazy with ram/cpu.
335342
336343 async def process_pdus_for_room(room_id: str):
337 logger.debug("Processing PDUs for %s", room_id)
338 try:
339 await self.check_server_matches_acl(origin_host, room_id)
340 except AuthError as e:
341 logger.warning("Ignoring PDUs for room %s from banned server", room_id)
344 with nested_logging_context(room_id):
345 logger.debug("Processing PDUs for %s", room_id)
346
347 try:
348 await self.check_server_matches_acl(origin_host, room_id)
349 except AuthError as e:
350 logger.warning(
351 "Ignoring PDUs for room %s from banned server", room_id
352 )
353 for pdu in pdus_by_room[room_id]:
354 event_id = pdu.event_id
355 pdu_results[event_id] = e.error_dict()
356 return
357
342358 for pdu in pdus_by_room[room_id]:
343 event_id = pdu.event_id
344 pdu_results[event_id] = e.error_dict()
345 return
346
347 for pdu in pdus_by_room[room_id]:
348 event_id = pdu.event_id
349 with pdu_process_time.time():
350 with nested_logging_context(event_id):
351 try:
352 await self._handle_received_pdu(origin, pdu)
353 pdu_results[event_id] = {}
354 except FederationError as e:
355 logger.warning("Error handling PDU %s: %s", event_id, e)
356 pdu_results[event_id] = {"error": str(e)}
357 except Exception as e:
358 f = failure.Failure()
359 pdu_results[event_id] = {"error": str(e)}
360 logger.error(
361 "Failed to handle PDU %s",
362 event_id,
363 exc_info=(f.type, f.value, f.getTracebackObject()),
364 )
359 pdu_results[pdu.event_id] = await process_pdu(pdu)
360
361 async def process_pdu(pdu: EventBase) -> JsonDict:
362 event_id = pdu.event_id
363 with pdu_process_time.time():
364 with nested_logging_context(event_id):
365 try:
366 await self._handle_received_pdu(origin, pdu)
367 return {}
368 except FederationError as e:
369 logger.warning("Error handling PDU %s: %s", event_id, e)
370 return {"error": str(e)}
371 except Exception as e:
372 f = failure.Failure()
373 logger.error(
374 "Failed to handle PDU %s",
375 event_id,
376 exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
377 )
378 return {"error": str(e)}
365379
366380 await concurrently_execute(
367381 process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
368382 )
369383
370384 if newest_pdu_ts and origin in self._federation_metrics_domains:
371 newest_pdu_age = self._clock.time_msec() - newest_pdu_ts
372 last_pdu_age_metric.labels(server_name=origin).set(newest_pdu_age / 1000)
385 last_pdu_ts_metric.labels(server_name=origin).set(newest_pdu_ts / 1000)
373386
374387 return pdu_results
375388
447460
448461 async def _on_state_ids_request_compute(self, room_id, event_id):
449462 state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
450 auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
463 auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
451464 return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
452465
453466 async def _on_context_state_request_compute(
454467 self, room_id: str, event_id: str
455468 ) -> Dict[str, list]:
456469 if event_id:
457 pdus = await self.handler.get_state_for_pdu(room_id, event_id)
470 pdus = await self.handler.get_state_for_pdu(
471 room_id, event_id
472 ) # type: Iterable[EventBase]
458473 else:
459474 pdus = (await self.state.get_current_state(room_id)).values()
460475
461 auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus])
476 auth_chain = await self.store.get_auth_chain(
477 room_id, [pdu.event_id for pdu in pdus]
478 )
462479
463480 return {
464481 "pdus": [pdu.get_pdu_json() for pdu in pdus],
862879 self.edu_handlers = (
863880 {}
864881 ) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
865 self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]]
882 self.query_handlers = (
883 {}
884 ) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]]
866885
867886 # Map from type to instance names that we should route EDU handling to.
868887 # We randomly choose one instance from the list to route to for each new
896915 self.edu_handlers[edu_type] = handler
897916
898917 def register_query_handler(
899 self, query_type: str, handler: Callable[[dict], defer.Deferred]
918 self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
900919 ):
901920 """Sets the handler callable that will be used to handle an incoming
902921 federation query of the given type.
969988 # Oh well, let's just log and move on.
970989 logger.warning("No handler registered for EDU type %s", edu_type)
971990
972 async def on_query(self, query_type: str, args: dict):
991 async def on_query(self, query_type: str, args: dict) -> JsonDict:
973992 handler = self.query_handlers.get(query_type)
974993 if handler:
975994 return await handler(args)
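The `_active_transactions` bookkeeping above amounts to per-origin mutual exclusion that rejects rather than queues, with the slot claimed before the first `await`; the shape of the pattern in a standalone toy:

```python
# Standalone toy of the per-origin gating used above (names illustrative).
import asyncio

active = {}  # origin -> transaction id currently being processed

async def handle_transaction(origin: str, txn_id: str) -> int:
    current = active.get(origin)
    if current is not None and current != txn_id:
        return 429  # another txn from this origin is already in flight
    # Claim the slot *before* the first await (the "critical section" above).
    active[origin] = txn_id
    try:
        await asyncio.sleep(0.1)  # stand-in for real processing
        return 200
    finally:
        active.pop(origin, None)

async def main():
    print(await asyncio.gather(
        handle_transaction("remote.example", "txn-a"),
        handle_transaction("remote.example", "txn-b"),
    ))  # -> [200, 429]

asyncio.run(main())
```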
1616 import logging
1717 from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast
1818
19 import attr
1920 from prometheus_client import Counter
2021
2122 from synapse.api.errors import (
9293 self._destination = destination
9394 self.transmission_loop_running = False
9495
96 # Flag to signal to any running transmission loop that there is new data
97 # queued up to be sent.
98 self._new_data_to_send = False
99
95100 # True whilst we are sending events that the remote homeserver missed
96101 # because it was unreachable. We start in this state so we can perform
97102 # catch-up at startup.
107112 # destination (we are the only updater so this is safe)
108113 self._last_successful_stream_ordering = None # type: Optional[int]
109114
110 # a list of pending PDUs
115 # a queue of pending PDUs
111116 self._pending_pdus = [] # type: List[EventBase]
112117
113118 # XXX this is never actually used: see
207212 transaction in the background.
208213 """
209214
215 # Mark that we (may) have new things to send, so that any running
216 # transmission loop will recheck whether there is stuff to send.
217 self._new_data_to_send = True
218
210219 if self.transmission_loop_running:
211220 # XXX: this can get stuck on by a never-ending
212221 # request at which point pending_pdus just keeps growing.
249258
250259 pending_pdus = []
251260 while True:
252 # We have to keep 2 free slots for presence and rr_edus
253 limit = MAX_EDUS_PER_TRANSACTION - 2
254
255 device_update_edus, dev_list_id = await self._get_device_update_edus(
256 limit
257 )
258
259 limit -= len(device_update_edus)
260
261 (
262 to_device_edus,
263 device_stream_id,
264 ) = await self._get_to_device_message_edus(limit)
265
266 pending_edus = device_update_edus + to_device_edus
267
268 # BEGIN CRITICAL SECTION
269 #
270 # In order to avoid a race condition, we need to make sure that
271 # the following code (from popping the queues up to the point
272 # where we decide if we actually have any pending messages) is
273 # atomic - otherwise new PDUs or EDUs might arrive in the
274 # meantime, but not get sent because we hold the
275 # transmission_loop_running flag.
276
277 pending_pdus = self._pending_pdus
278
279 # We can only include at most 50 PDUs per transactions
280 pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:]
281
282 pending_edus.extend(self._get_rr_edus(force_flush=False))
283 pending_presence = self._pending_presence
284 self._pending_presence = {}
285 if pending_presence:
286 pending_edus.append(
287 Edu(
288 origin=self._server_name,
289 destination=self._destination,
290 edu_type="m.presence",
291 content={
292 "push": [
293 format_user_presence_state(
294 presence, self._clock.time_msec()
295 )
296 for presence in pending_presence.values()
297 ]
298 },
261 self._new_data_to_send = False
262
263 async with _TransactionQueueManager(self) as (
264 pending_pdus,
265 pending_edus,
266 ):
267 if not pending_pdus and not pending_edus:
268 logger.debug("TX [%s] Nothing to send", self._destination)
269
270 # If we've gotten told about new things to send during
271 # checking for things to send, we try looking again.
272 # Otherwise new PDUs or EDUs might arrive in the meantime,
273 # but not get sent because we hold the
274 # `transmission_loop_running` flag.
275 if self._new_data_to_send:
276 continue
277 else:
278 return
279
280 if pending_pdus:
281 logger.debug(
282 "TX [%s] len(pending_pdus_by_dest[dest]) = %d",
283 self._destination,
284 len(pending_pdus),
299285 )
286
287 await self._transaction_manager.send_new_transaction(
288 self._destination, pending_pdus, pending_edus
300289 )
301290
302 pending_edus.extend(
303 self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
304 )
305 while (
306 len(pending_edus) < MAX_EDUS_PER_TRANSACTION
307 and self._pending_edus_keyed
308 ):
309 _, val = self._pending_edus_keyed.popitem()
310 pending_edus.append(val)
311
312 if pending_pdus:
313 logger.debug(
314 "TX [%s] len(pending_pdus_by_dest[dest]) = %d",
315 self._destination,
316 len(pending_pdus),
317 )
318
319 if not pending_pdus and not pending_edus:
320 logger.debug("TX [%s] Nothing to send", self._destination)
321 self._last_device_stream_id = device_stream_id
322 return
323
324 # if we've decided to send a transaction anyway, and we have room, we
325 # may as well send any pending RRs
326 if len(pending_edus) < MAX_EDUS_PER_TRANSACTION:
327 pending_edus.extend(self._get_rr_edus(force_flush=True))
328
329 # END CRITICAL SECTION
330
331 success = await self._transaction_manager.send_new_transaction(
332 self._destination, pending_pdus, pending_edus
333 )
334 if success:
335291 sent_transactions_counter.inc()
336292 sent_edus_counter.inc(len(pending_edus))
337293 for edu in pending_edus:
338294 sent_edus_by_type.labels(edu.edu_type).inc()
339 # Remove the acknowledged device messages from the database
340 # Only bother if we actually sent some device messages
341 if to_device_edus:
342 await self._store.delete_device_msgs_for_remote(
343 self._destination, device_stream_id
344 )
345
346 # also mark the device updates as sent
347 if device_update_edus:
348 logger.info(
349 "Marking as sent %r %r", self._destination, dev_list_id
350 )
351 await self._store.mark_as_sent_devices_by_remote(
352 self._destination, dev_list_id
353 )
354
355 self._last_device_stream_id = device_stream_id
356 self._last_device_list_stream_id = dev_list_id
357
358 if pending_pdus:
359 # we sent some PDUs and it was successful, so update our
360 # last_successful_stream_ordering in the destinations table.
361 final_pdu = pending_pdus[-1]
362 last_successful_stream_ordering = (
363 final_pdu.internal_metadata.stream_ordering
364 )
365 assert last_successful_stream_ordering
366 await self._store.set_destination_last_successful_stream_ordering(
367 self._destination, last_successful_stream_ordering
368 )
369 else:
370 break
295
371296 except NotRetryingDestination as e:
372297 logger.debug(
373298 "TX [%s] not ready for retry yet (next retry at %s) - "
400325 self._pending_presence = {}
401326 self._pending_rrs = {}
402327
403 self._start_catching_up()
328 self._start_catching_up()
404329 except FederationDeniedError as e:
405330 logger.info(e)
406331 except HttpResponseException as e:
411336 e,
412337 )
413338
414 self._start_catching_up()
415339 except RequestSendFailed as e:
416340 logger.warning(
417341 "TX [%s] Failed to send transaction: %s", self._destination, e
421345 logger.info(
422346 "Failed to send event %s to %s", p.event_id, self._destination
423347 )
424
425 self._start_catching_up()
426348 except Exception:
427349 logger.exception("TX [%s] Failed to send transaction", self._destination)
428350 for p in pending_pdus:
429351 logger.info(
430352 "Failed to send event %s to %s", p.event_id, self._destination
431353 )
432
433 self._start_catching_up()
434354 finally:
435355 # We want to be *very* sure we clear this after we stop processing
436356 self.transmission_loop_running = False
498418 rooms = [p.room_id for p in catchup_pdus]
499419 logger.info("Catching up rooms to %s: %r", self._destination, rooms)
500420
501 success = await self._transaction_manager.send_new_transaction(
421 await self._transaction_manager.send_new_transaction(
502422 self._destination, catchup_pdus, []
503423 )
504
505 if not success:
506 return
507424
508425 sent_transactions_counter.inc()
509426 final_pdu = catchup_pdus[-1]
583500 """
584501 self._catching_up = True
585502 self._pending_pdus = []
503
504
505 @attr.s(slots=True)
506 class _TransactionQueueManager:
507 """A helper async context manager for pulling stuff off the queues and
508 tracking what was last successfully sent, etc.
509 """
510
511 queue = attr.ib(type=PerDestinationQueue)
512
513 _device_stream_id = attr.ib(type=Optional[int], default=None)
514 _device_list_id = attr.ib(type=Optional[int], default=None)
515 _last_stream_ordering = attr.ib(type=Optional[int], default=None)
516 _pdus = attr.ib(type=List[EventBase], factory=list)
517
518 async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]:
519 # First we calculate the EDUs we want to send, if any.
520
521 # We start by fetching device related EDUs, i.e device updates and to
522 # device messages. We have to keep 2 free slots for presence and rr_edus.
523 limit = MAX_EDUS_PER_TRANSACTION - 2
524
525 device_update_edus, dev_list_id = await self.queue._get_device_update_edus(
526 limit
527 )
528
529 if device_update_edus:
530 self._device_list_id = dev_list_id
531 else:
532 self.queue._last_device_list_stream_id = dev_list_id
533
534 limit -= len(device_update_edus)
535
536 (
537 to_device_edus,
538 device_stream_id,
539 ) = await self.queue._get_to_device_message_edus(limit)
540
541 if to_device_edus:
542 self._device_stream_id = device_stream_id
543 else:
544 self.queue._last_device_stream_id = device_stream_id
545
546 pending_edus = device_update_edus + to_device_edus
547
548 # Now add the read receipt EDU.
549 pending_edus.extend(self.queue._get_rr_edus(force_flush=False))
550
551 # And presence EDU.
552 if self.queue._pending_presence:
553 pending_edus.append(
554 Edu(
555 origin=self.queue._server_name,
556 destination=self.queue._destination,
557 edu_type="m.presence",
558 content={
559 "push": [
560 format_user_presence_state(
561 presence, self.queue._clock.time_msec()
562 )
563 for presence in self.queue._pending_presence.values()
564 ]
565 },
566 )
567 )
568 self.queue._pending_presence = {}
569
570 # Finally add any other types of EDUs if there is room.
571 pending_edus.extend(
572 self.queue._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
573 )
574 while (
575 len(pending_edus) < MAX_EDUS_PER_TRANSACTION
576 and self.queue._pending_edus_keyed
577 ):
578 _, val = self.queue._pending_edus_keyed.popitem()
579 pending_edus.append(val)
580
581 # Now we look for any PDUs to send, by getting up to 50 PDUs from the
582 # queue
583 self._pdus = self.queue._pending_pdus[:50]
584
585 if not self._pdus and not pending_edus:
586 return [], []
587
588 # if we've decided to send a transaction anyway, and we have room, we
589 # may as well send any pending RRs
590 if len(pending_edus) < MAX_EDUS_PER_TRANSACTION:
591 pending_edus.extend(self.queue._get_rr_edus(force_flush=True))
592
593 if self._pdus:
594 self._last_stream_ordering = self._pdus[
595 -1
596 ].internal_metadata.stream_ordering
597 assert self._last_stream_ordering
598
599 return self._pdus, pending_edus
600
601 async def __aexit__(self, exc_type, exc, tb):
602 if exc_type is not None:
603 # Failed to send transaction, so we bail out.
604 return
605
606 # Successfully sent the transaction, so we remove the sent PDUs from the queue
607 if self._pdus:
608 self.queue._pending_pdus = self.queue._pending_pdus[len(self._pdus) :]
609
610 # We succeeded in sending the transaction, so we record where we have
611 # sent up to in the various streams
612
613 if self._device_stream_id:
614 await self.queue._store.delete_device_msgs_for_remote(
615 self.queue._destination, self._device_stream_id
616 )
617 self.queue._last_device_stream_id = self._device_stream_id
618
619 # also mark the device updates as sent
620 if self._device_list_id:
621 logger.info(
622 "Marking as sent %r %r", self.queue._destination, self._device_list_id
623 )
624 await self.queue._store.mark_as_sent_devices_by_remote(
625 self.queue._destination, self._device_list_id
626 )
627 self.queue._last_device_list_stream_id = self._device_list_id
628
629 if self._last_stream_ordering:
630 # we sent some PDUs and it was successful, so update our
631 # last_successful_stream_ordering in the destinations table.
632 await self.queue._store.set_destination_last_successful_stream_ordering(
633 self.queue._destination, self._last_stream_ordering
634 )
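
Taken together, the changes above move the queue-popping critical section into `_TransactionQueueManager`. A minimal sketch of the resulting transmission loop, simplified from the diff (exception handling and the catch-up path omitted):

    # Sketch only: names as in the diff above.
    while True:
        self._new_data_to_send = False

        async with _TransactionQueueManager(self) as (pending_pdus, pending_edus):
            if not pending_pdus and not pending_edus:
                # Re-check in case new data arrived while we were collecting;
                # otherwise it would sit unsent while we still hold the
                # `transmission_loop_running` flag.
                if self._new_data_to_send:
                    continue
                return

            await self._transaction_manager.send_new_transaction(
                self._destination, pending_pdus, pending_edus
            )

            # If send_new_transaction raised, __aexit__ sees the exception
            # and bails out; otherwise it pops the sent PDUs off the queue
            # and records how far we got in the device streams.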
3535
3636 logger = logging.getLogger(__name__)
3737
38 last_pdu_age_metric = Gauge(
39 "synapse_federation_last_sent_pdu_age",
40 "The age (in seconds) of the last PDU successfully sent to the given domain",
38 last_pdu_ts_metric = Gauge(
39 "synapse_federation_last_sent_pdu_time",
40 "The timestamp of the last PDU which was successfully sent to the given domain",
4141 labelnames=("server_name",),
4242 )
4343
6868 destination: str,
6969 pdus: List[EventBase],
7070 edus: List[Edu],
71 ) -> bool:
71 ) -> None:
7272 """
7373 Args:
7474 destination: The destination to send to (e.g. 'example.org')
7575 pdus: In-order list of PDUs to send
7676 edus: List of EDUs to send
77
78 Returns:
79 True iff the transaction was successful
8077 """
8178
8279 # Make a transaction-sending opentracing span. This span follows on from
9592 edu.strip_context()
9693
9794 with start_active_span_follows_from("send_transaction", span_contexts):
98 success = True
99
10095 logger.debug("TX [%s] _attempt_new_transaction", destination)
10196
10297 txn_id = str(self._next_txn_id)
151146 response = await self._transport_layer.send_transaction(
152147 transaction, json_data_cb
153148 )
154 code = 200
155149 except HttpResponseException as e:
156150 code = e.code
157151 response = e.response
158152
159 if e.code in (401, 404, 429) or 500 <= e.code:
160 logger.info(
161 "TX [%s] {%s} got %d response", destination, txn_id, code
162 )
163 raise e
153 set_tag(tags.ERROR, True)
164154
165 logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
155 logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
156 raise
166157
167 if code == 200:
168 for e_id, r in response.get("pdus", {}).items():
169 if "error" in r:
170 logger.warning(
171 "TX [%s] {%s} Remote returned error for %s: %s",
172 destination,
173 txn_id,
174 e_id,
175 r,
176 )
177 else:
178 for p in pdus:
158 logger.info("TX [%s] {%s} got 200 response", destination, txn_id)
159
160 for e_id, r in response.get("pdus", {}).items():
161 if "error" in r:
179162 logger.warning(
180 "TX [%s] {%s} Failed to send event %s",
163 "TX [%s] {%s} Remote returned error for %s: %s",
181164 destination,
182165 txn_id,
183 p.event_id,
166 e_id,
167 r,
184168 )
185 success = False
186169
187 if success and pdus and destination in self._federation_metrics_domains:
170 if pdus and destination in self._federation_metrics_domains:
188171 last_pdu = pdus[-1]
189 last_pdu_age = self.clock.time_msec() - last_pdu.origin_server_ts
190 last_pdu_age_metric.labels(server_name=destination).set(
191 last_pdu_age / 1000
172 last_pdu_ts_metric.labels(server_name=destination).set(
173 last_pdu.origin_server_ts / 1000
192174 )
193
194 set_tag(tags.ERROR, not success)
195 return success
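
Note that the semantics of the gauge change here: it now exports the origin timestamp of the last PDU (in seconds) rather than its age, so the age can be derived at query time (in Prometheus, `time() - synapse_federation_last_sent_pdu_time`). A hypothetical helper illustrating the relationship:

    import time

    def last_sent_pdu_age_seconds(last_pdu_ts: float) -> float:
        # last_pdu_ts is origin_server_ts / 1000, i.e. a Unix timestamp in
        # seconds, as exported by the gauge above.
        return time.time() - last_pdu_ts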
7272 "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port
7373 )
7474 try:
75 self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host)
75 self.reactor.listenTCP(
76 self.hs.config.acme_port, srv, backlog=50, interface=host
77 )
7678 except twisted.internet.error.CannotListenError as e:
7779 check_bind_error(e, host, bind_addresses)
7880
6464 from synapse.types import JsonDict, Requester, UserID
6565 from synapse.util import stringutils as stringutils
6666 from synapse.util.async_helpers import maybe_awaitable
67 from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
6768 from synapse.util.msisdn import phone_number_to_msisdn
6869 from synapse.util.threepids import canonicalise_email
6970
169170 extra_attributes = attr.ib(type=JsonDict)
170171
171172
173 @attr.s(slots=True, frozen=True)
174 class LoginTokenAttributes:
175 """Data we store in a short-term login token"""
176
177 user_id = attr.ib(type=str)
178
179 # the SSO Identity Provider that the user authenticated with, to get this token
180 auth_provider_id = attr.ib(type=str)
181
182
172183 class AuthHandler(BaseHandler):
173184 SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
174185
325336 user is too high to proceed
326337
327338 """
328
339 if not requester.access_token_id:
340 raise ValueError("Cannot validate a user without an access token")
329341 if self._ui_auth_session_timeout:
330342 last_validated = await self.store.get_access_token_last_validated(
331343 requester.access_token_id
11631175 return None
11641176 return user_id
11651177
1166 async def validate_short_term_login_token_and_get_user_id(self, login_token: str):
1167 auth_api = self.hs.get_auth()
1168 user_id = None
1178 async def validate_short_term_login_token(
1179 self, login_token: str
1180 ) -> LoginTokenAttributes:
11691181 try:
1170 macaroon = pymacaroons.Macaroon.deserialize(login_token)
1171 user_id = auth_api.get_user_id_from_macaroon(macaroon)
1172 auth_api.validate_macaroon(macaroon, "login", user_id)
1182 res = self.macaroon_gen.verify_short_term_login_token(login_token)
11731183 except Exception:
11741184 raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
11751185
1176 await self.auth.check_auth_blocking(user_id)
1177 return user_id
1186 await self.auth.check_auth_blocking(res.user_id)
1187 return res
11781188
11791189 async def delete_access_token(self, access_token: str):
11801190 """Invalidate a single access token
12031213 async def delete_access_tokens_for_user(
12041214 self,
12051215 user_id: str,
1206 except_token_id: Optional[str] = None,
1216 except_token_id: Optional[int] = None,
12071217 device_id: Optional[str] = None,
12081218 ):
12091219 """Invalidate access tokens belonging to a user
13961406 async def complete_sso_login(
13971407 self,
13981408 registered_user_id: str,
1409 auth_provider_id: str,
13991410 request: Request,
14001411 client_redirect_url: str,
14011412 extra_attributes: Optional[JsonDict] = None,
14051416
14061417 Args:
14071418 registered_user_id: The registered user ID to complete SSO login for.
1419 auth_provider_id: The id of the SSO Identity provider that was used for
1420 login. This will be stored in the login token for future tracking in
1421 prometheus metrics.
14081422 request: The request to complete.
14091423 client_redirect_url: The URL to which to redirect the user at the end of the
14101424 process.
14261440
14271441 self._complete_sso_login(
14281442 registered_user_id,
1443 auth_provider_id,
14291444 request,
14301445 client_redirect_url,
14311446 extra_attributes,
14361451 def _complete_sso_login(
14371452 self,
14381453 registered_user_id: str,
1454 auth_provider_id: str,
14391455 request: Request,
14401456 client_redirect_url: str,
14411457 extra_attributes: Optional[JsonDict] = None,
14621478
14631479 # Create a login token
14641480 login_token = self.macaroon_gen.generate_short_term_login_token(
1465 registered_user_id
1481 registered_user_id, auth_provider_id=auth_provider_id
14661482 )
14671483
14681484 # Append the login token to the original redirect URL (i.e. with its query
15681584 return macaroon.serialize()
15691585
15701586 def generate_short_term_login_token(
1571 self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
1587 self,
1588 user_id: str,
1589 auth_provider_id: str,
1590 duration_in_ms: int = (2 * 60 * 1000),
15721591 ) -> str:
15731592 macaroon = self._generate_base_macaroon(user_id)
15741593 macaroon.add_first_party_caveat("type = login")
15751594 now = self.hs.get_clock().time_msec()
15761595 expiry = now + duration_in_ms
15771596 macaroon.add_first_party_caveat("time < %d" % (expiry,))
1597 macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
15781598 return macaroon.serialize()
1599
1600 def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
1601 """Verify a short-term-login macaroon
1602
1603 Checks that the given token is a valid, unexpired short-term-login token
1604 minted by this server.
1605
1606 Args:
1607 token: the login token to verify
1608
1609 Returns:
1610 the user_id and auth_provider_id that this token is valid for
1611
1612 Raises:
1613 MacaroonVerificationFailedException if the verification failed
1614 """
1615 macaroon = pymacaroons.Macaroon.deserialize(token)
1616 user_id = get_value_from_macaroon(macaroon, "user_id")
1617 auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
1618
1619 v = pymacaroons.Verifier()
1620 v.satisfy_exact("gen = 1")
1621 v.satisfy_exact("type = login")
1622 v.satisfy_general(lambda c: c.startswith("user_id = "))
1623 v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
1624 satisfy_expiry(v, self.hs.get_clock().time_msec)
1625 v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
1626
1627 return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
15791628
15801629 def generate_delete_pusher_token(self, user_id: str) -> str:
15811630 macaroon = self._generate_base_macaroon(user_id)
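
A hypothetical round trip for the new token format, assuming a `MacaroonGenerator` instance as above:

    # Sketch only: `macaroon_gen` stands in for the generator shown above.
    token = macaroon_gen.generate_short_term_login_token(
        "@alice:example.com", auth_provider_id="oidc"
    )
    res = macaroon_gen.verify_short_term_login_token(token)
    assert res.user_id == "@alice:example.com"
    assert res.auth_provider_id == "oidc"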
8282 # the SsoIdentityProvider protocol type.
8383 self.idp_icon = None
8484 self.idp_brand = None
85 self.unstable_idp_brand = None
8586
8687 self._sso_handler = hs.get_sso_handler()
8788
200200 or pdu.internal_metadata.is_outlier()
201201 )
202202 if already_seen:
203 logger.debug("[%s %s]: Already seen pdu", room_id, event_id)
203 logger.debug("Already seen pdu")
204204 return
205205
206206 # do some initial sanity-checking of the event. In particular, make
209209 try:
210210 self._sanity_check_event(pdu)
211211 except SynapseError as err:
212 logger.warning(
213 "[%s %s] Received event failed sanity checks", room_id, event_id
214 )
212 logger.warning("Received event failed sanity checks")
215213 raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
216214
217215 # If we are currently in the process of joining this room, then we
218216 # queue up events for later processing.
219217 if room_id in self.room_queues:
220218 logger.info(
221 "[%s %s] Queuing PDU from %s for now: join in progress",
222 room_id,
223 event_id,
219 "Queuing PDU from %s for now: join in progress",
224220 origin,
225221 )
226222 self.room_queues[room_id].append((pdu, origin))
235231 is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
236232 if not is_in_room:
237233 logger.info(
238 "[%s %s] Ignoring PDU from %s as we're not in the room",
239 room_id,
240 event_id,
234 "Ignoring PDU from %s as we're not in the room",
241235 origin,
242236 )
243237 return None
249243 # We only backfill backwards to the min depth.
250244 min_depth = await self.get_min_depth_for_context(pdu.room_id)
251245
252 logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
246 logger.debug("min_depth: %d", min_depth)
253247
254248 prevs = set(pdu.prev_event_ids())
255249 seen = await self.store.have_events_in_timeline(prevs)
266260 # If we're missing stuff, ensure we only fetch stuff one
267261 # at a time.
268262 logger.info(
269 "[%s %s] Acquiring room lock to fetch %d missing prev_events: %s",
270 room_id,
271 event_id,
263 "Acquiring room lock to fetch %d missing prev_events: %s",
272264 len(missing_prevs),
273265 shortstr(missing_prevs),
274266 )
275267 with (await self._room_pdu_linearizer.queue(pdu.room_id)):
276268 logger.info(
277 "[%s %s] Acquired room lock to fetch %d missing prev_events",
278 room_id,
279 event_id,
269 "Acquired room lock to fetch %d missing prev_events",
280270 len(missing_prevs),
281271 )
282272
296286
297287 if not prevs - seen:
298288 logger.info(
299 "[%s %s] Found all missing prev_events",
300 room_id,
301 event_id,
289 "Found all missing prev_events",
302290 )
303291
304292 if prevs - seen:
328316
329317 if sent_to_us_directly:
330318 logger.warning(
331 "[%s %s] Rejecting: failed to fetch %d prev events: %s",
332 room_id,
333 event_id,
319 "Rejecting: failed to fetch %d prev events: %s",
334320 len(prevs - seen),
335321 shortstr(prevs - seen),
336322 )
366352 # Ask the remote server for the states we don't
367353 # know about
368354 for p in prevs - seen:
369 logger.info(
370 "Requesting state at missing prev_event %s",
371 event_id,
372 )
355 logger.info("Requesting state after missing prev_event %s", p)
373356
374357 with nested_logging_context(p):
375358 # note that if any of the missing prevs share missing state or
376359 # auth events, the requests to fetch those events are deduped
377360 # by the get_pdu_cache in federation_client.
378 (remote_state, _,) = await self._get_state_for_room(
379 origin, room_id, p, include_event_in_state=True
361 remote_state = (
362 await self._get_state_after_missing_prev_event(
363 origin, room_id, p
364 )
380365 )
381366
382367 remote_state_map = {
413398 state = [event_map[e] for e in state_map.values()]
414399 except Exception:
415400 logger.warning(
416 "[%s %s] Error attempting to resolve state at missing "
417 "prev_events",
418 room_id,
419 event_id,
401 "Error attempting to resolve state at missing " "prev_events",
420402 exc_info=True,
421403 )
422404 raise FederationError(
453435 latest |= seen
454436
455437 logger.info(
456 "[%s %s]: Requesting missing events between %s and %s",
457 room_id,
458 event_id,
438 "Requesting missing events between %s and %s",
459439 shortstr(latest),
460440 event_id,
461441 )
522502 # We failed to get the missing events, but since we need to handle
523503 # the case of `get_missing_events` not returning the necessary
524504 # events anyway, it is safe to simply log the error and continue.
525 logger.warning(
526 "[%s %s]: Failed to get prev_events: %s", room_id, event_id, e
527 )
505 logger.warning("Failed to get prev_events: %s", e)
528506 return
529507
530508 logger.info(
531 "[%s %s]: Got %d prev_events: %s",
532 room_id,
533 event_id,
509 "Got %d prev_events: %s",
534510 len(missing_events),
535511 shortstr(missing_events),
536512 )
541517
542518 for ev in missing_events:
543519 logger.info(
544 "[%s %s] Handling received prev_event %s",
545 room_id,
546 event_id,
520 "Handling received prev_event %s",
547521 ev.event_id,
548522 )
549523 with nested_logging_context(ev.event_id):
552526 except FederationError as e:
553527 if e.code == 403:
554528 logger.warning(
555 "[%s %s] Received prev_event %s failed history check.",
556 room_id,
557 event_id,
529 "Received prev_event %s failed history check.",
558530 ev.event_id,
559531 )
560532 else:
565537 destination: str,
566538 room_id: str,
567539 event_id: str,
568 include_event_in_state: bool = False,
569540 ) -> Tuple[List[EventBase], List[EventBase]]:
570541 """Requests all of the room state at a given event from a remote homeserver.
571542
573544 destination: The remote homeserver to query for the state.
574545 room_id: The id of the room we're interested in.
575546 event_id: The id of the event we want the state at.
576 include_event_in_state: if true, the event itself will be included in the
577 returned state event list.
578547
579548 Returns:
580 A list of events in the state, possibly including the event itself, and
549 A list of events in the state, not including the event itself, and
581550 a list of events in the auth chain for the given event.
582551 """
583552 (
588557 )
589558
590559 desired_events = set(state_event_ids + auth_event_ids)
591
592 if include_event_in_state:
593 desired_events.add(event_id)
594560
595561 event_map = await self._get_events_from_store_or_dest(
596562 destination, room_id, desired_events
607573 remote_state = [
608574 event_map[e_id] for e_id in state_event_ids if e_id in event_map
609575 ]
610
611 if include_event_in_state:
612 remote_event = event_map.get(event_id)
613 if not remote_event:
614 raise Exception("Unable to get missing prev_event %s" % (event_id,))
615 if remote_event.is_state() and remote_event.rejected_reason is None:
616 remote_state.append(remote_event)
617576
618577 auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
619578 auth_chain.sort(key=lambda e: e.depth)
688647
689648 return fetched_events
690649
650 async def _get_state_after_missing_prev_event(
651 self,
652 destination: str,
653 room_id: str,
654 event_id: str,
655 ) -> List[EventBase]:
656 """Requests all of the room state at a given event from a remote homeserver.
657
658 Args:
659 destination: The remote homeserver to query for the state.
660 room_id: The id of the room we're interested in.
661 event_id: The id of the event we want the state at.
662
663 Returns:
664 A list of events in the state, including the event itself
665 """
666 # TODO: This function is basically the same as _get_state_for_room. Can
667 # we make backfill() use it, rather than having two code paths? I think the
668 # only difference is that backfill() persists the prev events separately.
669
670 (
671 state_event_ids,
672 auth_event_ids,
673 ) = await self.federation_client.get_room_state_ids(
674 destination, room_id, event_id=event_id
675 )
676
677 logger.debug(
678 "state_ids returned %i state events, %i auth events",
679 len(state_event_ids),
680 len(auth_event_ids),
681 )
682
683 # start by just trying to fetch the events from the store
684 desired_events = set(state_event_ids)
685 desired_events.add(event_id)
686 logger.debug("Fetching %i events from cache/store", len(desired_events))
687 fetched_events = await self.store.get_events(
688 desired_events, allow_rejected=True
689 )
690
691 missing_desired_events = desired_events - fetched_events.keys()
692 logger.debug(
693 "We are missing %i events (got %i)",
694 len(missing_desired_events),
695 len(fetched_events),
696 )
697
698 # We probably won't need most of the auth events, so let's just check which
699 # we have for now, rather than thrashing the event cache with them all
700 # unnecessarily.
701
702 # TODO: we probably won't actually need all of the auth events, since we
703 # already have a bunch of the state events. It would be nice if the
704 # federation api gave us a way of finding out which we actually need.
705
706 missing_auth_events = set(auth_event_ids) - fetched_events.keys()
707 missing_auth_events.difference_update(
708 await self.store.have_seen_events(missing_auth_events)
709 )
710 logger.debug("We are also missing %i auth events", len(missing_auth_events))
711
712 missing_events = missing_desired_events | missing_auth_events
713 logger.debug("Fetching %i events from remote", len(missing_events))
714 await self._get_events_and_persist(
715 destination=destination, room_id=room_id, events=missing_events
716 )
717
718 # we need to make sure we re-load from the database to get the rejected
719 # state correct.
720 fetched_events.update(
721 (await self.store.get_events(missing_desired_events, allow_rejected=True))
722 )
723
724 # check for events which were in the wrong room.
725 #
726 # this can happen if a remote server claims that the state or
727 # auth_events at an event in room A are actually events in room B
728
729 bad_events = [
730 (event_id, event.room_id)
731 for event_id, event in fetched_events.items()
732 if event.room_id != room_id
733 ]
734
735 for bad_event_id, bad_room_id in bad_events:
736 # This is a bogus situation, but since we may only discover it a long time
737 # after it happened, we try our best to carry on, by just omitting the
738 # bad events from the returned state set.
739 logger.warning(
740 "Remote server %s claims event %s in room %s is an auth/state "
741 "event in room %s",
742 destination,
743 bad_event_id,
744 bad_room_id,
745 room_id,
746 )
747
748 del fetched_events[bad_event_id]
749
750 # if we couldn't get the prev event in question, that's a problem.
751 remote_event = fetched_events.get(event_id)
752 if not remote_event:
753 raise Exception("Unable to get missing prev_event %s" % (event_id,))
754
755 # missing state at that event is a warning, not a blocker
756 # XXX: this doesn't sound right? it means that we'll end up with incomplete
757 # state.
758 failed_to_fetch = desired_events - fetched_events.keys()
759 if failed_to_fetch:
760 logger.warning(
761 "Failed to fetch missing state events for %s %s",
762 event_id,
763 failed_to_fetch,
764 )
765
766 remote_state = [
767 fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
768 ]
769
770 if remote_event.is_state() and remote_event.rejected_reason is None:
771 remote_state.append(remote_event)
772
773 return remote_state
774
691775 async def _process_received_pdu(
692776 self,
693777 origin: str,
706790 (ie, we are missing one or more prev_events), the resolved state at the
707791 event
708792 """
709 room_id = event.room_id
710 event_id = event.event_id
711
712 logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
793 logger.debug("Processing event: %s", event)
713794
714795 try:
715796 await self._handle_new_event(origin, event, state=state)
870951 destination=dest,
871952 room_id=room_id,
872953 event_id=e_id,
873 include_event_in_state=False,
874954 )
875955 auth_events.update({a.event_id: a for a in auth})
876956 auth_events.update({s.event_id: s for s in state})
13161396 async def on_event_auth(self, event_id: str) -> List[EventBase]:
13171397 event = await self.store.get_event(event_id)
13181398 auth = await self.store.get_auth_chain(
1319 list(event.auth_event_ids()), include_given=True
1399 event.room_id, list(event.auth_event_ids()), include_given=True
13201400 )
13211401 return list(auth)
13221402
15791659 prev_state_ids = await context.get_prev_state_ids()
15801660
15811661 state_ids = list(prev_state_ids.values())
1582 auth_chain = await self.store.get_auth_chain(state_ids)
1662 auth_chain = await self.store.get_auth_chain(event.room_id, state_ids)
15831663
15841664 state = await self.store.get_events(list(prev_state_ids.values()))
15851665
22182298
22192299 # Now get the current auth_chain for the event.
22202300 local_auth_chain = await self.store.get_auth_chain(
2221 list(event.auth_event_ids()), include_given=True
2301 room_id, list(event.auth_event_ids()), include_given=True
22222302 )
22232303
22242304 # TODO: Check if we would now reject event_id. If so we need to tell
4747 self.clock = hs.get_clock()
4848 self.validator = EventValidator()
4949 self.snapshot_cache = ResponseCache(
50 hs, "initial_sync_cache"
50 hs.get_clock(), "initial_sync_cache"
5151 ) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
5252 self._event_serializer = hs.get_event_client_serializer()
5353 self.storage = hs.get_storage()
00 # -*- coding: utf-8 -*-
11 # Copyright 2020 Quentin Gliech
2 # Copyright 2021 The Matrix.org Foundation C.I.C.
23 #
34 # Licensed under the Apache License, Version 2.0 (the "License");
45 # you may not use this file except in compliance with the License.
1314 # limitations under the License.
1415 import inspect
1516 import logging
16 from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
17 from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
1718 from urllib.parse import urlencode
1819
1920 import attr
2021 import pymacaroons
2122 from authlib.common.security import generate_token
22 from authlib.jose import JsonWebToken
23 from authlib.jose import JsonWebToken, jwt
2324 from authlib.oauth2.auth import ClientAuth
2425 from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
2526 from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
2728 from jinja2 import Environment, Template
2829 from pymacaroons.exceptions import (
2930 MacaroonDeserializationException,
31 MacaroonInitException,
3032 MacaroonInvalidSignatureException,
3133 )
3234 from typing_extensions import TypedDict
3335
3436 from twisted.web.client import readBody
37 from twisted.web.http_headers import Headers
3538
3639 from synapse.config import ConfigError
37 from synapse.config.oidc_config import OidcProviderConfig
40 from synapse.config.oidc_config import (
41 OidcProviderClientSecretJwtKey,
42 OidcProviderConfig,
43 )
3844 from synapse.handlers.sso import MappingException, UserAttributes
3945 from synapse.http.site import SynapseRequest
4046 from synapse.logging.context import make_deferred_yieldable
4147 from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
42 from synapse.util import json_decoder
48 from synapse.util import Clock, json_decoder
4349 from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
50 from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
4451
4552 if TYPE_CHECKING:
4653 from synapse.server import HomeServer
210217 session_data = self._token_generator.verify_oidc_session_token(
211218 session, state
212219 )
213 except (MacaroonDeserializationException, ValueError) as e:
220 except (MacaroonInitException, MacaroonDeserializationException, KeyError) as e:
214221 logger.exception("Invalid session for OIDC callback")
215222 self._sso_handler.render_error(request, "invalid_session", str(e))
216223 return
274281
275282 self._scopes = provider.scopes
276283 self._user_profile_method = provider.user_profile_method
284
285 client_secret = None # type: Union[None, str, JwtClientSecret]
286 if provider.client_secret:
287 client_secret = provider.client_secret
288 elif provider.client_secret_jwt_key:
289 client_secret = JwtClientSecret(
290 provider.client_secret_jwt_key,
291 provider.client_id,
292 provider.issuer,
293 hs.get_clock(),
294 )
295
277296 self._client_auth = ClientAuth(
278297 provider.client_id,
279 provider.client_secret,
298 client_secret,
280299 provider.client_auth_method,
281300 ) # type: ClientAuth
282301 self._client_auth_method = provider.client_auth_method
310329
311330 # optional brand identifier for this auth provider
312331 self.idp_brand = provider.idp_brand
332
333 # Optional brand identifier for the unstable API (see MSC2858).
334 self.unstable_idp_brand = provider.unstable_idp_brand
313335
314336 self._sso_handler = hs.get_sso_handler()
315337
520542 """
521543 metadata = await self.load_metadata()
522544 token_endpoint = metadata.get("token_endpoint")
523 headers = {
545 raw_headers = {
524546 "Content-Type": "application/x-www-form-urlencoded",
525547 "User-Agent": self._http_client.user_agent,
526548 "Accept": "application/json",
534556 body = urlencode(args, True)
535557
536558 # Fill the body/headers with credentials
537 uri, headers, body = self._client_auth.prepare(
538 method="POST", uri=token_endpoint, headers=headers, body=body
539 )
540 headers = {k: [v] for (k, v) in headers.items()}
559 uri, raw_headers, body = self._client_auth.prepare(
560 method="POST", uri=token_endpoint, headers=raw_headers, body=body
561 )
562 headers = Headers({k: [v] for (k, v) in raw_headers.items()})
541563
542564 # Do the actual request
543565 # We're not using the SimpleHttpClient util methods as we don't want to
744766 idp_id=self.idp_id,
745767 nonce=nonce,
746768 client_redirect_url=client_redirect_url.decode(),
747 ui_auth_session_id=ui_auth_session_id,
769 ui_auth_session_id=ui_auth_session_id or "",
748770 ),
749771 )
750772
975997 return str(remote_user_id)
976998
977999
1000 # number of seconds a newly-generated client secret should be valid for
1001 CLIENT_SECRET_VALIDITY_SECONDS = 3600
1002
1003 # minimum remaining validity on a client secret before we should generate a new one
1004 CLIENT_SECRET_MIN_VALIDITY_SECONDS = 600
1005
1006
1007 class JwtClientSecret:
1008 """A class which generates a new client secret on demand, based on a JWK
1009
1010 This implementation is designed to comply with the requirements for Apple Sign in:
1011 https://developer.apple.com/documentation/sign_in_with_apple/generate_and_validate_tokens#3262048
1012
1013 It looks like those requirements are based on https://tools.ietf.org/html/rfc7523,
1014 but it's worth noting that we still put the generated secret in the "client_secret"
1015 field (or rather, wherever client_auth_method puts it) rather than in a
1016 client_assertion field in the body as that RFC seems to require.
1017 """
1018
1019 def __init__(
1020 self,
1021 key: OidcProviderClientSecretJwtKey,
1022 oauth_client_id: str,
1023 oauth_issuer: str,
1024 clock: Clock,
1025 ):
1026 self._key = key
1027 self._oauth_client_id = oauth_client_id
1028 self._oauth_issuer = oauth_issuer
1029 self._clock = clock
1030 self._cached_secret = b""
1031 self._cached_secret_replacement_time = 0
1032
1033 def __str__(self):
1034 # if client_auth_method is client_secret_basic, then ClientAuth.prepare calls
1035 # encode_client_secret_basic, which calls "{}".format(secret), which ends up
1036 # here.
1037 return self._get_secret().decode("ascii")
1038
1039 def __bytes__(self):
1040 # if client_auth_method is client_secret_post, then ClientAuth.prepare calls
1041 # encode_client_secret_post, which ends up here.
1042 return self._get_secret()
1043
1044 def _get_secret(self) -> bytes:
1045 now = self._clock.time()
1046
1047 # if we have enough validity on our existing secret, use it
1048 if now < self._cached_secret_replacement_time:
1049 return self._cached_secret
1050
1051 issued_at = int(now)
1052 expires_at = issued_at + CLIENT_SECRET_VALIDITY_SECONDS
1053
1054 # we copy the configured header because jwt.encode modifies it.
1055 header = dict(self._key.jwt_header)
1056
1057 # see https://tools.ietf.org/html/rfc7523#section-3
1058 payload = {
1059 "sub": self._oauth_client_id,
1060 "aud": self._oauth_issuer,
1061 "iat": issued_at,
1062 "exp": expires_at,
1063 **self._key.jwt_payload,
1064 }
1065 logger.info(
1066 "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
1067 )
1068 self._cached_secret = jwt.encode(header, payload, self._key.key)
1069 self._cached_secret_replacement_time = (
1070 expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
1071 )
1072 return self._cached_secret
1073
1074
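
A hedged sketch of how this class gets wired up; the shape of `OidcProviderClientSecretJwtKey` is inferred from the attribute accesses above (`key`, `jwt_header`, `jwt_payload`), and the constructor shown here is an assumption:

    # Illustrative only; `signing_key_pem` is a hypothetical EC private key.
    jwt_key = OidcProviderClientSecretJwtKey(
        key=signing_key_pem,
        jwt_header={"alg": "ES256", "kid": "my-key-id"},
        jwt_payload={"iss": "my-team-id"},  # merged into the JWT payload
    )
    secret = JwtClientSecret(
        jwt_key, "my.client.id", "https://issuer.example", hs.get_clock()
    )
    # ClientAuth.prepare() stringifies the secret, which lazily mints (and
    # caches) a signed JWT via __str__/__bytes__ above.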
9781075 class OidcSessionTokenGenerator:
9791076 """Methods for generating and checking OIDC Session cookies."""
9801077
10191116 macaroon.add_first_party_caveat(
10201117 "client_redirect_url = %s" % (session_data.client_redirect_url,)
10211118 )
1022 if session_data.ui_auth_session_id:
1023 macaroon.add_first_party_caveat(
1024 "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
1025 )
1119 macaroon.add_first_party_caveat(
1120 "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
1121 )
10261122 now = self._clock.time_msec()
10271123 expiry = now + duration_in_ms
10281124 macaroon.add_first_party_caveat("time < %d" % (expiry,))
10451141 The data extracted from the session cookie
10461142
10471143 Raises:
1048 ValueError if an expected caveat is missing from the macaroon.
1144 KeyError if an expected caveat is missing from the macaroon.
10491145 """
10501146 macaroon = pymacaroons.Macaroon.deserialize(session)
10511147
10561152 v.satisfy_general(lambda c: c.startswith("nonce = "))
10571153 v.satisfy_general(lambda c: c.startswith("idp_id = "))
10581154 v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
1059 # Sometimes there's a UI auth session ID, it seems to be OK to attempt
1060 # to always satisfy this.
10611155 v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
1062 v.satisfy_general(self._verify_expiry)
1156 satisfy_expiry(v, self._clock.time_msec)
10631157
10641158 v.verify(macaroon, self._macaroon_secret_key)
10651159
10661160 # Extract the session data from the token.
1067 nonce = self._get_value_from_macaroon(macaroon, "nonce")
1068 idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
1069 client_redirect_url = self._get_value_from_macaroon(
1070 macaroon, "client_redirect_url"
1071 )
1072 try:
1073 ui_auth_session_id = self._get_value_from_macaroon(
1074 macaroon, "ui_auth_session_id"
1075 ) # type: Optional[str]
1076 except ValueError:
1077 ui_auth_session_id = None
1078
1161 nonce = get_value_from_macaroon(macaroon, "nonce")
1162 idp_id = get_value_from_macaroon(macaroon, "idp_id")
1163 client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url")
1164 ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id")
10791165 return OidcSessionData(
10801166 nonce=nonce,
10811167 idp_id=idp_id,
10831169 ui_auth_session_id=ui_auth_session_id,
10841170 )
10851171
1086 def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
1087 """Extracts a caveat value from a macaroon token.
1088
1089 Args:
1090 macaroon: the token
1091 key: the key of the caveat to extract
1092
1093 Returns:
1094 The extracted value
1095
1096 Raises:
1097 ValueError: if the caveat was not in the macaroon
1098 """
1099 prefix = key + " = "
1100 for caveat in macaroon.caveats:
1101 if caveat.caveat_id.startswith(prefix):
1102 return caveat.caveat_id[len(prefix) :]
1103 raise ValueError("No %s caveat in macaroon" % (key,))
1104
1105 def _verify_expiry(self, caveat: str) -> bool:
1106 prefix = "time < "
1107 if not caveat.startswith(prefix):
1108 return False
1109 expiry = int(caveat[len(prefix) :])
1110 now = self._clock.time_msec()
1111 return now < expiry
1112
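
The two removed helpers are promoted to a shared module, imported near the top of this file as `from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry`. A sketch of the shared versions, inferred from the removed code and the new `KeyError` behaviour documented above (the real module may differ in detail):

    from typing import Callable

    import pymacaroons

    def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
        # Extract a "key = value" caveat, raising KeyError (rather than the
        # old ValueError) if it is absent.
        prefix = key + " = "
        for caveat in macaroon.caveats:
            if caveat.caveat_id.startswith(prefix):
                return caveat.caveat_id[len(prefix) :]
        raise KeyError("No %s caveat in macaroon" % (key,))

    def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None:
        # Register a general check for "time < <expiry_ms>" caveats.
        def verify_expiry_caveat(caveat: str) -> bool:
            prefix = "time < "
            if not caveat.startswith(prefix):
                return False
            return get_time_ms() < int(caveat[len(prefix) :])

        v.satisfy_general(verify_expiry_caveat)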
11131172
11141173 @attr.s(frozen=True, slots=True)
11151174 class OidcSessionData:
11241183 # The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
11251184 client_redirect_url = attr.ib(type=str)
11261185
1127 # The session ID of the ongoing UI Auth (None if this is a login)
1128 ui_auth_session_id = attr.ib(type=Optional[str], default=None)
1186 # The session ID of the ongoing UI Auth ("" if this is a login)
1187 ui_auth_session_id = attr.ib(type=str)
11291188
11301189
11311190 UserAttributeDict = TypedDict(
284284 except Exception:
285285 f = Failure()
286286 logger.error(
287 "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject())
287 "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) # type: ignore
288288 )
289289 self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
290290 finally:
1515 """Contains functions for registering clients."""
1616
1717 import logging
18 from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
18 from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
19
20 from prometheus_client import Counter
1921
2022 from synapse import types
2123 from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
4042 logger = logging.getLogger(__name__)
4143
4244
45 registration_counter = Counter(
46 "synapse_user_registrations_total",
47 "Number of new users registered (since restart)",
48 ["guest", "shadow_banned", "auth_provider"],
49 )
50
51 login_counter = Counter(
52 "synapse_user_logins_total",
53 "Number of user logins (since restart)",
54 ["guest", "auth_provider"],
55 )
56
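
For illustration, the counters can be exercised directly with `prometheus_client` (hypothetical values; the real increments happen further down in this diff):

    from prometheus_client import REGISTRY, generate_latest

    registration_counter.labels(
        guest=False, shadow_banned=False, auth_provider="oidc"
    ).inc()
    print(generate_latest(REGISTRY).decode("ascii"))
    # ...includes a line like:
    # synapse_user_registrations_total{auth_provider="oidc",guest="False",shadow_banned="False"} 1.0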
57
4358 class RegistrationHandler(BaseHandler):
4459 def __init__(self, hs: "HomeServer"):
4560 super().__init__(hs)
6681 )
6782 else:
6883 self.device_handler = hs.get_device_handler()
84 self._register_device_client = self.register_device_inner
6985 self.pusher_pool = hs.get_pusherpool()
7086
7187 self.session_lifetime = hs.config.session_lifetime
155171 bind_emails: Iterable[str] = [],
156172 by_admin: bool = False,
157173 user_agent_ips: Optional[List[Tuple[str, str]]] = None,
174 auth_provider_id: Optional[str] = None,
158175 ) -> str:
159176 """Registers a new client on the server.
160177
180197 admin api, otherwise False.
181198 user_agent_ips: Tuples of IP addresses and user-agents used
182199 during the registration process.
200 auth_provider_id: The SSO IdP the user used, if any.
183201 Returns:
184 The registere user_id.
202 The registered user_id.
185203 Raises:
186204 SynapseError if there was a problem registering.
187205 """
191209 threepid,
192210 localpart,
193211 user_agent_ips or [],
212 auth_provider_id=auth_provider_id,
194213 )
195214
196215 if result == RegistrationBehaviour.DENY:
278297 except SynapseError:
279298 # if user id is taken, just generate another
280299 fail_count += 1
300
301 registration_counter.labels(
302 guest=make_guest,
303 shadow_banned=shadow_banned,
304 auth_provider=(auth_provider_id or ""),
305 ).inc()
281306
282307 if not self.hs.config.user_consent_at_registration:
283308 if not self.hs.config.auto_join_rooms_for_guests and make_guest:
637662 initial_display_name: Optional[str],
638663 is_guest: bool = False,
639664 is_appservice_ghost: bool = False,
665 auth_provider_id: Optional[str] = None,
640666 ) -> Tuple[str, str]:
641667 """Register a device for a user and generate an access token.
642668
647673 device_id: The device ID to check, or None to generate a new one.
648674 initial_display_name: An optional display name for the device.
649675 is_guest: Whether this is a guest account
650
676 auth_provider_id: The SSO IdP the user used, if any (just used for the
677 prometheus metrics).
651678 Returns:
652679 Tuple of device ID and access token
653680 """
654
655 if self.hs.config.worker_app:
656 r = await self._register_device_client(
657 user_id=user_id,
658 device_id=device_id,
659 initial_display_name=initial_display_name,
660 is_guest=is_guest,
661 is_appservice_ghost=is_appservice_ghost,
662 )
663 return r["device_id"], r["access_token"]
664
681 res = await self._register_device_client(
682 user_id=user_id,
683 device_id=device_id,
684 initial_display_name=initial_display_name,
685 is_guest=is_guest,
686 is_appservice_ghost=is_appservice_ghost,
687 )
688
689 login_counter.labels(
690 guest=is_guest,
691 auth_provider=(auth_provider_id or ""),
692 ).inc()
693
694 return res["device_id"], res["access_token"]
695
696 async def register_device_inner(
697 self,
698 user_id: str,
699 device_id: Optional[str],
700 initial_display_name: Optional[str],
701 is_guest: bool = False,
702 is_appservice_ghost: bool = False,
703 ) -> Dict[str, str]:
704 """Helper for register_device
705
706 Does the bits that need doing on the main process. Not for use outside this
707 class and RegisterDeviceReplicationServlet.
708 """
709 assert not self.hs.config.worker_app
665710 valid_until_ms = None
666711 if self.session_lifetime is not None:
667712 if is_guest:
686731 is_appservice_ghost=is_appservice_ghost,
687732 )
688733
689 return (registered_device_id, access_token)
734 return {"device_id": registered_device_id, "access_token": access_token}
690735
691736 async def post_registration_actions(
692737 self, user_id: str, auth_result: dict, access_token: Optional[str]
120120 # succession, only process the first attempt and return its result to
121121 # subsequent requests
122122 self._upgrade_response_cache = ResponseCache(
123 hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
123 hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
124124 ) # type: ResponseCache[Tuple[str, str]]
125125 self._server_notices_mxid = hs.config.server_notices_mxid
126126
4343 super().__init__(hs)
4444 self.enable_room_list_search = hs.config.enable_room_list_search
4545 self.response_cache = ResponseCache(
46 hs, "room_list"
46 hs.get_clock(), "room_list"
4747 ) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]]
4848 self.remote_response_cache = ResponseCache(
49 hs, "remote_room_list", timeout_ms=30 * 1000
49 hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000
5050 ) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
5151
5252 async def get_local_public_room_list(
8080 # the SsoIdentityProvider protocol type.
8181 self.idp_icon = None
8282 self.idp_brand = None
83 self.unstable_idp_brand = None
8384
8485 # a map from saml session id to Saml2SessionData object
8586 self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
9797 """Optional branding identifier"""
9898 return None
9999
100 @property
101 def unstable_idp_brand(self) -> Optional[str]:
102 """Optional brand identifier for the unstable API (see MSC2858)."""
103 return None
104
100105 @abc.abstractmethod
101106 async def handle_redirect_request(
102107 self,
455460
456461 await self._auth_handler.complete_sso_login(
457462 user_id,
463 auth_provider_id,
458464 request,
459465 client_redirect_url,
460466 extra_login_attributes,
604610 default_display_name=attributes.display_name,
605611 bind_emails=attributes.emails,
606612 user_agent_ips=[(user_agent, ip_address)],
613 auth_provider_id=auth_provider_id,
607614 )
608615
609616 await self._store.record_user_external_id(
885892
886893 await self._auth_handler.complete_sso_login(
887894 user_id,
895 session.auth_provider_id,
888896 request,
889897 session.client_redirect_url,
890898 session.extra_login_attributes,
243243 self.event_sources = hs.get_event_sources()
244244 self.clock = hs.get_clock()
245245 self.response_cache = ResponseCache(
246 hs, "sync"
246 hs.get_clock(), "sync"
247247 ) # type: ResponseCache[Tuple[Any, ...]]
248248 self.state = hs.get_state_handler()
249249 self.auth = hs.get_auth()
3838 from OpenSSL import SSL
3939 from OpenSSL.SSL import VERIFY_NONE
4040 from twisted.internet import defer, error as twisted_error, protocol, ssl
41 from twisted.internet.address import IPv4Address, IPv6Address
4142 from twisted.internet.interfaces import (
4243 IAddress,
4344 IHostResolution,
4445 IReactorPluggableNameResolver,
4546 IResolutionReceiver,
47 ITCPTransport,
4648 )
49 from twisted.internet.protocol import connectionDone
4750 from twisted.internet.task import Cooperator
4851 from twisted.python.failure import Failure
4952 from twisted.web._newclient import ResponseDone
5558 )
5659 from twisted.web.http import PotentialDataLoss
5760 from twisted.web.http_headers import Headers
58 from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
61 from twisted.web.iweb import (
62 UNKNOWN_LENGTH,
63 IAgent,
64 IBodyProducer,
65 IPolicyForHTTPS,
66 IResponse,
67 )
5968
6069 from synapse.api.errors import Codes, HttpResponseException, SynapseError
6170 from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
6271 from synapse.http.proxyagent import ProxyAgent
6372 from synapse.logging.context import make_deferred_yieldable
6473 from synapse.logging.opentracing import set_tag, start_active_span, tags
74 from synapse.types import ISynapseReactor
6575 from synapse.util import json_decoder
6676 from synapse.util.async_helpers import timeout_deferred
6777
149159 def resolveHostName(
150160 self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
151161 ) -> IResolutionReceiver:
152
153 r = recv()
154162 addresses = [] # type: List[IAddress]
155163
156164 def _callback() -> None:
157 r.resolutionBegan(None)
158
159165 has_bad_ip = False
160 for i in addresses:
161 ip_address = IPAddress(i.host)
166 for address in addresses:
167 # We only expect IPv4 and IPv6 addresses since only A/AAAA lookups
168 # should go through this path.
169 if not isinstance(address, (IPv4Address, IPv6Address)):
170 continue
171
172 ip_address = IPAddress(address.host)
162173
163174 if check_against_blacklist(
164175 ip_address, self._ip_whitelist, self._ip_blacklist
173184 # request, but all we can really do from here is claim that there were no
174185 # valid results.
175186 if not has_bad_ip:
176 for i in addresses:
177 r.addressResolved(i)
178 r.resolutionComplete()
187 for address in addresses:
188 recv.addressResolved(address)
189 recv.resolutionComplete()
179190
180191 @provider(IResolutionReceiver)
181192 class EndpointReceiver:
182193 @staticmethod
183194 def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
184 pass
195 recv.resolutionBegan(resolutionInProgress)
185196
186197 @staticmethod
187198 def addressResolved(address: IAddress) -> None:
195206 EndpointReceiver, hostname, portNumber=portNumber
196207 )
197208
198 return r
199
200
201 @implementer(IReactorPluggableNameResolver)
209 return recv
210
211
212 @implementer(ISynapseReactor)
202213 class BlacklistingReactorWrapper:
203214 """
204215 A Reactor wrapper which will prevent DNS resolution to blacklisted IP
323334 # filters out blacklisted IP addresses, to prevent DNS rebinding.
324335 self.reactor = BlacklistingReactorWrapper(
325336 hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
326 )
337 ) # type: ISynapseReactor
327338 else:
328339 self.reactor = hs.get_reactor()
329340
344355 contextFactory=self.hs.get_http_client_context_factory(),
345356 pool=pool,
346357 use_proxy=use_proxy,
347 )
358 ) # type: IAgent
348359
349360 if self._ip_blacklist:
350361 # If we have an IP blacklist, we then install the blacklisting Agent
750761 class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
751762 """A protocol which immediately errors upon receiving data."""
752763
764 transport = None # type: Optional[ITCPTransport]
765
753766 def __init__(self, deferred: defer.Deferred):
754767 self.deferred = deferred
755768
761774 self.deferred.errback(BodyExceededMaxSize())
762775 # Close the connection (forcefully) since all the data will get
763776 # discarded anyway.
777 assert self.transport is not None
764778 self.transport.abortConnection()
765779
766780 def dataReceived(self, data: bytes) -> None:
767781 self._maybe_fail()
768782
769 def connectionLost(self, reason: Failure) -> None:
783 def connectionLost(self, reason: Failure = connectionDone) -> None:
770784 self._maybe_fail()
771785
772786
773787 class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
774788 """A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
789
790 transport = None # type: Optional[ITCPTransport]
775791
776792 def __init__(
777793 self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
795811 self.deferred.errback(BodyExceededMaxSize())
796812 # Close the connection (forcefully) since all the data will get
797813 # discarded anyway.
814 assert self.transport is not None
798815 self.transport.abortConnection()
799816
800 def connectionLost(self, reason: Failure) -> None:
817 def connectionLost(self, reason: Failure = connectionDone) -> None:
801818 # If the maximum size was already exceeded, there's nothing to do.
802819 if self.deferred.called:
803820 return
866883 return query_str.encode("utf8")
867884
868885
886 @implementer(IPolicyForHTTPS)
869887 class InsecureInterceptableContextFactory(ssl.ContextFactory):
870888 """
871889 Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.
3434 from synapse.http.federation.srv_resolver import Server, SrvResolver
3535 from synapse.http.federation.well_known_resolver import WellKnownResolver
3636 from synapse.logging.context import make_deferred_yieldable, run_in_background
37 from synapse.types import ISynapseReactor
3738 from synapse.util import Clock
3839
3940 logger = logging.getLogger(__name__)
6768
6869 def __init__(
6970 self,
70 reactor: IReactorCore,
71 reactor: ISynapseReactor,
7172 tls_client_options_factory: Optional[FederationPolicyForHTTPS],
7273 user_agent: bytes,
7374 ip_blacklist: IPSet,
321321
322322 def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
323323 cache_controls = {}
324 for hdr in headers.getRawHeaders(b"cache-control", []):
324 cache_control_headers = headers.getRawHeaders(b"cache-control") or []
325 for hdr in cache_control_headers:
325326 for directive in hdr.split(b","):
326327 splits = [x.strip() for x in directive.split(b"=", 1)]
327328 k = splits[0].lower()
5858 start_active_span,
5959 tags,
6060 )
61 from synapse.types import JsonDict
61 from synapse.types import ISynapseReactor, JsonDict
6262 from synapse.util import json_decoder
6363 from synapse.util.async_helpers import timeout_deferred
6464 from synapse.util.metrics import Measure
236236 # addresses, to prevent DNS rebinding.
237237 self.reactor = BlacklistingReactorWrapper(
238238 hs.get_reactor(), None, hs.config.federation_ip_range_blacklist
239 )
239 ) # type: ISynapseReactor
240240
241241 user_agent = hs.version_string
242242 if hs.config.user_agent_suffix:
243243 user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
244244 user_agent = user_agent.encode("ascii")
245245
246 self.agent = MatrixFederationAgent(
246 federation_agent = MatrixFederationAgent(
247247 self.reactor,
248248 tls_client_options_factory,
249249 user_agent,
253253 # Use a BlacklistingAgentWrapper to prevent circumventing the IP
254254 # blacklist via IP literals in server names
255255 self.agent = BlacklistingAgentWrapper(
256 self.agent,
256 federation_agent,
257257 ip_blacklist=hs.config.federation_ip_range_blacklist,
258258 )
259259
533533 response.code, response_phrase, body
534534 )
535535
536 # Retry if the error is a 429 (Too Many Requests),
537 # otherwise just raise a standard HttpResponseException
538 if response.code == 429:
536 # Retry if the error is a 5xx or a 429 (Too Many
537 # Requests), otherwise just raise a standard
538 # `HttpResponseException`
539 if 500 <= response.code < 600 or response.code == 429:
539540 raise RequestSendFailed(exc, can_retry=True) from exc
540541 else:
541542 raise exc
3131 TCP4ClientEndpoint,
3232 TCP6ClientEndpoint,
3333 )
34 from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
34 from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
3535 from twisted.internet.protocol import Factory, Protocol
36 from twisted.internet.tcp import Connection
3637 from twisted.python.failure import Failure
3738
3839 logger = logging.getLogger(__name__)
5152 format: A callable to format the log record to a string.
5253 """
5354
54 transport = attr.ib(type=ITransport)
55 # This is essentially ITCPTransport, but that is missing certain fields
56 # (connected and registerProducer) which are part of the implementation.
57 transport = attr.ib(type=Connection)
5558 _format = attr.ib(type=Callable[[logging.LogRecord], str])
5659 _buffer = attr.ib(type=deque)
5760 _paused = attr.ib(default=False, type=bool, init=False)
148151 if self._connection_waiter:
149152 return
150153
151 self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
152
153154 def fail(failure: Failure) -> None:
154155 # If the Deferred was cancelled (e.g. during shutdown) do not try to
155156 # reconnect (this will cause an infinite loop of errors).
162163 self._connect()
163164
164165 def writer(result: Protocol) -> None:
166 # Force recognising transport as a Connection and not the more
167 # generic ITransport.
168 transport = result.transport # type: Connection # type: ignore
169
165170 # We have a connection. If we already have a producer, and its
166171 # transport is the same, just trigger a resumeProducing.
167 if self._producer and result.transport is self._producer.transport:
172 if self._producer and transport is self._producer.transport:
168173 self._producer.resumeProducing()
169174 self._connection_waiter = None
170175 return
176181 # Make a new producer and start it.
177182 self._producer = LogProducer(
178183 buffer=self._buffer,
179 transport=result.transport,
184 transport=transport,
180185 format=self.format,
181186 )
182 result.transport.registerProducer(self._producer, True)
187 transport.registerProducer(self._producer, True)
183188 self._producer.resumeProducing()
184189 self._connection_waiter = None
185190
186 self._connection_waiter.addCallbacks(writer, fail)
191 deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred
192 deferred.addCallbacks(writer, fail)
193 self._connection_waiter = deferred
187194
188195 def _handle_pressure(self) -> None:
189196 """
668668 return g
669669
670670
671 def run_in_background(f, *args, **kwargs):
671 def run_in_background(f, *args, **kwargs) -> defer.Deferred:
672672 """Calls a function, ensuring that the current context is restored after
673673 return from the function, and that the sentinel context is set once the
674674 deferred returned by the function completes.
696696 if isinstance(res, types.CoroutineType):
697697 res = defer.ensureDeferred(res)
698698
699 # At this point we should have a Deferred; if not, then f was a synchronous
700 # function, so wrap its result in a Deferred for consistency.
699701 if not isinstance(res, defer.Deferred):
700 return res
702 return defer.succeed(res)
701703
702704 if res.called and not res.paused:
703705 # The function should have maintained the logcontext, so we can
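Aside: a standalone sketch of the behaviour change above. run_in_background now always hands back a Deferred, so callers no longer need to special-case synchronous functions; only the wrapping step is shown here:

    from twisted.internet import defer

    def ensure_deferred_result(res):
        # Mirrors the new code path: a plain (synchronous) return value is
        # wrapped with defer.succeed; Deferreds pass through untouched.
        if not isinstance(res, defer.Deferred):
            return defer.succeed(res)
        return res

    d = ensure_deferred_result(42)
    d.addCallback(print)  # prints 42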
202202 )
203203
204204 def generate_short_term_login_token(
205 self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
205 self,
206 user_id: str,
207 duration_in_ms: int = (2 * 60 * 1000),
208 auth_provider_id: str = "",
206209 ) -> str:
207 """Generate a login token suitable for m.login.token authentication"""
210 """Generate a login token suitable for m.login.token authentication
211
212 Args:
213 user_id: gives the ID of the user that the token is for
214
215 duration_in_ms: the time that the token will be valid for
216
217 auth_provider_id: the ID of the SSO IdP that the user used to authenticate
218 to get this token, if any. This is encoded in the token so that
219 /login can report stats on number of successful logins by IdP.
220 """
208221 return self._hs.get_macaroon_generator().generate_short_term_login_token(
209 user_id, duration_in_ms
222 user_id,
223 auth_provider_id,
224 duration_in_ms,
210225 )
211226
212227 @defer.inlineCallbacks
275290 """
276291 self._auth_handler._complete_sso_login(
277292 registered_user_id,
293 "<unknown>",
278294 request,
279295 client_redirect_url,
280296 )
285301 request: SynapseRequest,
286302 client_redirect_url: str,
287303 new_user: bool = False,
304 auth_provider_id: str = "<unknown>",
288305 ):
289306 """Complete a SSO login by redirecting the user to a page to confirm whether they
290307 want their access token sent to `client_redirect_url`, or redirect them to that
298315 redirect them directly if whitelisted).
299316 new_user: set to true to use wording for the consent appropriate to a user
300317 who has just registered.
318 auth_provider_id: the ID of the SSO IdP which was used to log in. This
319 is used to track counts of successful logins by IdP.
301320 """
302321 await self._auth_handler.complete_sso_login(
303 registered_user_id, request, client_redirect_url, new_user=new_user
322 registered_user_id,
323 auth_provider_id,
324 request,
325 client_redirect_url,
326 new_user=new_user,
304327 )
305328
306329 @defer.inlineCallbacks
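Aside: a usage sketch of the extended module-API method above, from a hypothetical pluggable module. The module_api argument is assumed to be a synapse.module_api.ModuleApi instance and "my-idp" is a placeholder IdP ID:

    def issue_sso_token(module_api, user_id: str) -> str:
        # module_api is assumed to be a ModuleApi instance.
        return module_api.generate_short_term_login_token(
            user_id,
            duration_in_ms=2 * 60 * 1000,
            auth_provider_id="my-idp",  # recorded so /login can count logins per IdP
        )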
1515 import logging
1616 from typing import TYPE_CHECKING, Dict, List, Optional
1717
18 from twisted.internet.base import DelayedCall
1918 from twisted.internet.error import AlreadyCalled, AlreadyCancelled
19 from twisted.internet.interfaces import IDelayedCall
2020
2121 from synapse.metrics.background_process_metrics import run_as_background_process
2222 from synapse.push import Pusher, PusherConfig, ThrottleParams
6565
6666 self.store = self.hs.get_datastore()
6767 self.email = pusher_config.pushkey
68 self.timed_call = None # type: Optional[DelayedCall]
68 self.timed_call = None # type: Optional[IDelayedCall]
6969 self.throttle_params = {} # type: Dict[str, ThrottleParams]
7070 self._inited = False
7171
1717 import re
1818 import urllib
1919 from inspect import signature
20 from typing import Dict, List, Tuple
20 from typing import TYPE_CHECKING, Dict, List, Tuple
2121
2222 from prometheus_client import Counter, Gauge
2323
2626 from synapse.logging.opentracing import inject_active_span_byte_dict, trace
2727 from synapse.util.caches.response_cache import ResponseCache
2828 from synapse.util.stringutils import random_string
29
30 if TYPE_CHECKING:
31 from synapse.server import HomeServer
2932
3033 logger = logging.getLogger(__name__)
3134
8790 CACHE = True
8891 RETRY_ON_TIMEOUT = True
8992
90 def __init__(self, hs):
93 def __init__(self, hs: "HomeServer"):
9194 if self.CACHE:
9295 self.response_cache = ResponseCache(
93 hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
96 hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000
9497 ) # type: ResponseCache[str]
9598
9699 # We reserve `instance_name` as a parameter to sending requests, so we
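Aside: the `if TYPE_CHECKING:` import above is the standard way to give type checkers access to HomeServer without creating an import cycle at runtime. A minimal sketch of the pattern (the class name here is hypothetical):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only imported while type checking; never at runtime.
        from synapse.server import HomeServer

    class SomeReplicationEndpoint:
        def __init__(self, hs: "HomeServer"):
            # The string annotation means the name need not exist at runtime.
            self.clock = hs.get_clock()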
6060 is_guest = content["is_guest"]
6161 is_appservice_ghost = content["is_appservice_ghost"]
6262
63 device_id, access_token = await self.registration_handler.register_device(
63 res = await self.registration_handler.register_device_inner(
6464 user_id,
6565 device_id,
6666 initial_display_name,
6868 is_appservice_ghost=is_appservice_ghost,
6969 )
7070
71 return 200, {"device_id": device_id, "access_token": access_token}
71 return 200, res
7272
7373
7474 def register_servlets(hs, http_server):
4747 UserIpCommand,
4848 UserSyncCommand,
4949 )
50 from synapse.replication.tcp.protocol import AbstractConnection
50 from synapse.replication.tcp.protocol import IReplicationConnection
5151 from synapse.replication.tcp.streams import (
5252 STREAMS_MAP,
5353 AccountDataStream,
8181
8282 # the type of the entries in _command_queues_by_stream
8383 _StreamCommandQueue = Deque[
84 Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
84 Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection]
8585 ]
8686
8787
173173
174174 # The currently connected connections. (The list of places we need to send
175175 # outgoing replication commands to.)
176 self._connections = [] # type: List[AbstractConnection]
176 self._connections = [] # type: List[IReplicationConnection]
177177
178178 LaterGauge(
179179 "synapse_replication_tcp_resource_total_connections",
196196
197197 # For each connection, the incoming stream names that have received a POSITION
198198 # from that connection.
199 self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
199 self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]]
200200
201201 LaterGauge(
202202 "synapse_replication_tcp_command_queue",
219219 self._server_notices_sender = hs.get_server_notices_sender()
220220
221221 def _add_command_to_stream_queue(
222 self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
222 self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
223223 ) -> None:
224224 """Queue the given received command for processing
225225
266266 async def _process_command(
267267 self,
268268 cmd: Union[PositionCommand, RdataCommand],
269 conn: AbstractConnection,
269 conn: IReplicationConnection,
270270 stream_name: str,
271271 ) -> None:
272272 if isinstance(cmd, PositionCommand):
301301 hs, outbound_redis_connection
302302 )
303303 hs.get_reactor().connectTCP(
304 hs.config.redis.redis_host,
304 hs.config.redis.redis_host.encode(),
305305 hs.config.redis.redis_port,
306306 self._factory,
307307 )
310310 self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
311311 host = hs.config.worker_replication_host
312312 port = hs.config.worker_replication_port
313 hs.get_reactor().connectTCP(host, port, self._factory)
313 hs.get_reactor().connectTCP(host.encode(), port, self._factory)
314314
315315 def get_streams(self) -> Dict[str, Stream]:
316316 """Get a map from stream name to all streams."""
320320 """Get a list of streams that this instances replicates."""
321321 return self._streams_to_replicate
322322
323 def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
323 def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
324324 self.send_positions_to_connection(conn)
325325
326 def send_positions_to_connection(self, conn: AbstractConnection):
326 def send_positions_to_connection(self, conn: IReplicationConnection):
327327 """Send current position of all streams this process is source of to
328328 the connection.
329329 """
346346 )
347347
348348 def on_USER_SYNC(
349 self, conn: AbstractConnection, cmd: UserSyncCommand
349 self, conn: IReplicationConnection, cmd: UserSyncCommand
350350 ) -> Optional[Awaitable[None]]:
351351 user_sync_counter.inc()
352352
358358 return None
359359
360360 def on_CLEAR_USER_SYNC(
361 self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
361 self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand
362362 ) -> Optional[Awaitable[None]]:
363363 if self._is_master:
364364 return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
365365 else:
366366 return None
367367
368 def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
368 def on_FEDERATION_ACK(
369 self, conn: IReplicationConnection, cmd: FederationAckCommand
370 ):
369371 federation_ack_counter.inc()
370372
371373 if self._federation_sender:
372374 self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
373375
374376 def on_USER_IP(
375 self, conn: AbstractConnection, cmd: UserIpCommand
377 self, conn: IReplicationConnection, cmd: UserIpCommand
376378 ) -> Optional[Awaitable[None]]:
377379 user_ip_cache_counter.inc()
378380
394396 assert self._server_notices_sender is not None
395397 await self._server_notices_sender.on_user_ip(cmd.user_id)
396398
397 def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
399 def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
398400 if cmd.instance_name == self._instance_name:
399401 # Ignore RDATA that are just our own echoes
400402 return
411413 self._add_command_to_stream_queue(conn, cmd)
412414
413415 async def _process_rdata(
414 self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
416 self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand
415417 ) -> None:
416418 """Process an RDATA command
417419
485487 stream_name, instance_name, token, rows
486488 )
487489
488 def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
490 def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
489491 if cmd.instance_name == self._instance_name:
490492 # Ignore POSITION that are just our own echoes
491493 return
495497 self._add_command_to_stream_queue(conn, cmd)
496498
497499 async def _process_position(
498 self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
500 self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand
499501 ) -> None:
500502 """Process a POSITION command
501503
552554
553555 self._streams_by_connection.setdefault(conn, set()).add(stream_name)
554556
555 def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
557 def on_REMOTE_SERVER_UP(
558 self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
559 ):
556560 """"Called when get a new REMOTE_SERVER_UP command."""
557561 self._replication_data_handler.on_remote_server_up(cmd.data)
558562
575579 # between two instances, but that is not currently supported).
576580 self.send_command(cmd, ignore_conn=conn)
577581
578 def new_connection(self, connection: AbstractConnection):
582 def new_connection(self, connection: IReplicationConnection):
579583 """Called when we have a new connection."""
580584 self._connections.append(connection)
581585
602606 UserSyncCommand(self._instance_id, user_id, True, now)
603607 )
604608
605 def lost_connection(self, connection: AbstractConnection):
609 def lost_connection(self, connection: IReplicationConnection):
606610 """Called when a connection is closed/lost."""
607611 # we no longer need _streams_by_connection for this connection.
608612 streams = self._streams_by_connection.pop(connection, None)
623627 return bool(self._connections)
624628
625629 def send_command(
626 self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
630 self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
627631 ):
628632 """Send a command to all connected connections.
629633
4545 > ERROR server stopping
4646 * connection closed by server *
4747 """
48 import abc
4948 import fcntl
5049 import logging
5150 import struct
5352 from typing import TYPE_CHECKING, List, Optional
5453
5554 from prometheus_client import Counter
55 from zope.interface import Interface, implementer
5656
5757 from twisted.internet import task
58 from twisted.internet.tcp import Connection
5859 from twisted.protocols.basic import LineOnlyReceiver
5960 from twisted.python.failure import Failure
6061
120121 CLOSED = "closed"
121122
122123
124 class IReplicationConnection(Interface):
125 """An interface for replication connections."""
126
127 def send_command(cmd: Command):
128 """Send the command down the connection"""
129
130
131 @implementer(IReplicationConnection)
123132 class BaseReplicationStreamProtocol(LineOnlyReceiver):
124133 """Base replication protocol shared between client and server.
125134
136145 (if they send a `PING` command)
137146 """
138147
148 # The transport is going to be an ITCPTransport, but that doesn't have the
149 # (un)registerProducer methods; those are only on the implementation.
150 transport = None # type: Connection
151
139152 delimiter = b"\n"
140153
141154 # Valid commands we expect to receive
180193
181194 connected_connections.append(self) # Register connection for metrics
182195
196 assert self.transport is not None
183197 self.transport.registerProducer(self, True) # For the *Producing callbacks
184198
185199 self._send_pending_commands()
204218 logger.info(
205219 "[%s] Failed to close connection gracefully, aborting", self.id()
206220 )
221 assert self.transport is not None
207222 self.transport.abortConnection()
208223 else:
209224 if now - self.last_sent_command >= PING_TIME:
293308 def close(self):
294309 logger.warning("[%s] Closing connection", self.id())
295310 self.time_we_closed = self.clock.time_msec()
311 assert self.transport is not None
296312 self.transport.loseConnection()
297313 self.on_connection_closed()
298314
390406 def connectionLost(self, reason):
391407 logger.info("[%s] Replication connection closed: %r", self.id(), reason)
392408 if isinstance(reason, Failure):
409 assert reason.type is not None
393410 connection_close_counter.labels(reason.type.__name__).inc()
394411 else:
395412 connection_close_counter.labels(reason.__class__.__name__).inc()
494511 self.send_command(ReplicateCommand())
495512
496513
497 class AbstractConnection(abc.ABC):
498 """An interface for replication connections."""
499
500 @abc.abstractmethod
501 def send_command(self, cmd: Command):
502 """Send the command down the connection"""
503 pass
504
505
506 # This tells python that `BaseReplicationStreamProtocol` implements the
507 # interface.
508 AbstractConnection.register(BaseReplicationStreamProtocol)
509
510
511514 # The following simply registers metrics for the replication connections
512515
513516 pending_commands = LaterGauge(
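Aside: a minimal sketch (not from the changeset) of the zope.interface pattern that replaces the ABC above. An Interface declares the contract (its methods take no self), @implementer marks a class as providing it, and verifyObject can check conformance at runtime:

    from zope.interface import Interface, implementer
    from zope.interface.verify import verifyObject

    class IGreeter(Interface):
        def greet(name):
            """Return a greeting for `name`."""

    @implementer(IGreeter)
    class Greeter:
        def greet(self, name):
            return "hello %s" % (name,)

    verifyObject(IGreeter, Greeter())  # raises if the contract is broken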
1818
1919 import attr
2020 import txredisapi
21 from zope.interface import implementer
22
23 from twisted.internet.address import IPv4Address, IPv6Address
24 from twisted.internet.interfaces import IAddress, IConnector
25 from twisted.python.failure import Failure
2126
2227 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
2328 from synapse.metrics.background_process_metrics import (
3136 parse_command_from_line,
3237 )
3338 from synapse.replication.tcp.protocol import (
34 AbstractConnection,
39 IReplicationConnection,
3540 tcp_inbound_commands_counter,
3641 tcp_outbound_commands_counter,
3742 )
6166 pass
6267
6368
64 class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
69 @implementer(IReplicationConnection)
70 class RedisSubscriber(txredisapi.SubscriberProtocol):
6571 """Connection to redis subscribed to replication stream.
6672
6773 This class fulfils two functions:
7076 connection, parsing *incoming* messages into replication commands, and passing them
7177 to `ReplicationCommandHandler`
7278
73 (b) it implements the AbstractConnection API, where it sends *outgoing* commands
79 (b) it implements the IReplicationConnection API, where it sends *outgoing* commands
7480 onto outbound_redis_connection.
7581
7682 Due to the vagaries of `txredisapi` we don't want to have a custom
252258 except Exception:
253259 logger.warning("Failed to send ping to a redis connection")
254260
261 # ReconnectingClientFactory has some logging (if you enable `self.noisy`), but
262 # it's rubbish. We add our own here.
263
264 def startedConnecting(self, connector: IConnector):
265 logger.info(
266 "Connecting to redis server %s", format_address(connector.getDestination())
267 )
268 super().startedConnecting(connector)
269
270 def clientConnectionFailed(self, connector: IConnector, reason: Failure):
271 logger.info(
272 "Connection to redis server %s failed: %s",
273 format_address(connector.getDestination()),
274 reason.value,
275 )
276 super().clientConnectionFailed(connector, reason)
277
278 def clientConnectionLost(self, connector: IConnector, reason: Failure):
279 logger.info(
280 "Connection to redis server %s lost: %s",
281 format_address(connector.getDestination()),
282 reason.value,
283 )
284 super().clientConnectionLost(connector, reason)
285
286
287 def format_address(address: IAddress) -> str:
288 if isinstance(address, (IPv4Address, IPv6Address)):
289 return "%s:%i" % (address.host, address.port)
290 return str(address)
291
255292
256293 class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
257294 """This is a reconnecting factory that connects to redis and immediately
327364 factory.continueTrying = reconnect
328365
329366 reactor = hs.get_reactor()
330 reactor.connectTCP(host, port, factory, 30)
367 reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None)
331368
332369 return factory.handler
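Aside: a standalone sketch exercising the format_address helper added above with Twisted's address classes (the helper body is restated here so the snippet runs on its own):

    from twisted.internet.address import IPv4Address, IPv6Address, UNIXAddress

    def format_address(address) -> str:
        # Same logic as the helper introduced above.
        if isinstance(address, (IPv4Address, IPv6Address)):
            return "%s:%i" % (address.host, address.port)
        return str(address)

    print(format_address(IPv4Address("TCP", "10.1.2.3", 6379)))  # "10.1.2.3:6379"
    print(format_address(UNIXAddress(b"/var/run/redis.sock")))   # falls back to str()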
1414
1515 import re
1616
17 import twisted.web.server
18
19 import synapse.api.auth
17 from synapse.api.auth import Auth
2018 from synapse.api.errors import AuthError
19 from synapse.http.site import SynapseRequest
2120 from synapse.types import UserID
2221
2322
3635 return patterns
3736
3837
39 async def assert_requester_is_admin(
40 auth: synapse.api.auth.Auth, request: twisted.web.server.Request
41 ) -> None:
38 async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None:
4239 """Verify that the requester is an admin user
4340
4441 Args:
45 auth: api.auth.Auth singleton
42 auth: Auth singleton
4643 request: incoming request
4744
4845 Raises:
5249 await assert_user_is_admin(auth, requester.user)
5350
5451
55 async def assert_user_is_admin(auth: synapse.api.auth.Auth, user_id: UserID) -> None:
52 async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None:
5653 """Verify that the given user is an admin user
5754
5855 Args:
59 auth: api.auth.Auth singleton
56 auth: Auth singleton
6057 user_id: user to check
6158
6259 Raises:
1616 import logging
1717 from typing import TYPE_CHECKING, Tuple
1818
19 from twisted.web.server import Request
20
2119 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
2220 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
21 from synapse.http.site import SynapseRequest
2322 from synapse.rest.admin._base import (
2423 admin_patterns,
2524 assert_requester_is_admin,
4948 self.store = hs.get_datastore()
5049 self.auth = hs.get_auth()
5150
52 async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
51 async def on_POST(
52 self, request: SynapseRequest, room_id: str
53 ) -> Tuple[int, JsonDict]:
5354 requester = await self.auth.get_user_by_req(request)
5455 await assert_user_is_admin(self.auth, requester.user)
5556
7475 self.store = hs.get_datastore()
7576 self.auth = hs.get_auth()
7677
77 async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
78 async def on_POST(
79 self, request: SynapseRequest, user_id: str
80 ) -> Tuple[int, JsonDict]:
7881 requester = await self.auth.get_user_by_req(request)
7982 await assert_user_is_admin(self.auth, requester.user)
8083
102105 self.auth = hs.get_auth()
103106
104107 async def on_POST(
105 self, request: Request, server_name: str, media_id: str
108 self, request: SynapseRequest, server_name: str, media_id: str
106109 ) -> Tuple[int, JsonDict]:
107110 requester = await self.auth.get_user_by_req(request)
108111 await assert_user_is_admin(self.auth, requester.user)
126129 self.store = hs.get_datastore()
127130 self.auth = hs.get_auth()
128131
129 async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
132 async def on_POST(
133 self, request: SynapseRequest, media_id: str
134 ) -> Tuple[int, JsonDict]:
130135 requester = await self.auth.get_user_by_req(request)
131136 await assert_user_is_admin(self.auth, requester.user)
132137
147152 self.store = hs.get_datastore()
148153 self.auth = hs.get_auth()
149154
150 async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
155 async def on_GET(
156 self, request: SynapseRequest, room_id: str
157 ) -> Tuple[int, JsonDict]:
151158 requester = await self.auth.get_user_by_req(request)
152159 is_admin = await self.auth.is_server_admin(requester.user)
153160 if not is_admin:
165172 self.media_repository = hs.get_media_repository()
166173 self.auth = hs.get_auth()
167174
168 async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
175 async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
169176 await assert_requester_is_admin(self.auth, request)
170177
171178 before_ts = parse_integer(request, "before_ts", required=True)
188195 self.media_repository = hs.get_media_repository()
189196
190197 async def on_DELETE(
191 self, request: Request, server_name: str, media_id: str
198 self, request: SynapseRequest, server_name: str, media_id: str
192199 ) -> Tuple[int, JsonDict]:
193200 await assert_requester_is_admin(self.auth, request)
194201
217224 self.server_name = hs.hostname
218225 self.media_repository = hs.get_media_repository()
219226
220 async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
227 async def on_POST(
228 self, request: SynapseRequest, server_name: str
229 ) -> Tuple[int, JsonDict]:
221230 await assert_requester_is_admin(self.auth, request)
222231
223232 before_ts = parse_integer(request, "before_ts", required=True)
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 from typing import TYPE_CHECKING, Tuple
15
1416 from synapse.http.servlet import (
1517 RestServlet,
1618 assert_params_in_dict,
1719 parse_json_object_from_request,
1820 )
21 from synapse.http.site import SynapseRequest
1922 from synapse.rest.admin import assert_requester_is_admin
2023 from synapse.rest.admin._base import admin_patterns
24 from synapse.types import JsonDict
25
26 if TYPE_CHECKING:
27 from synapse.server import HomeServer
2128
2229
2330 class PurgeRoomServlet(RestServlet):
3542
3643 PATTERNS = admin_patterns("/purge_room$")
3744
38 def __init__(self, hs):
39 """
40 Args:
41 hs (synapse.server.HomeServer): server
42 """
45 def __init__(self, hs: "HomeServer"):
4346 self.hs = hs
4447 self.auth = hs.get_auth()
4548 self.pagination_handler = hs.get_pagination_handler()
4649
47 async def on_POST(self, request):
50 async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
4851 await assert_requester_is_admin(self.auth, request)
4952
5053 body = parse_json_object_from_request(request)
684684 results["events_after"], time_now
685685 )
686686 results["state"] = await self._event_serializer.serialize_events(
687 results["state"], time_now
687 results["state"],
688 time_now,
689 # No need to bundle aggregations for state events
690 bundle_aggregations=False,
688691 )
689692
690693 return 200, results
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 from typing import TYPE_CHECKING, Optional, Tuple
15
1416 from synapse.api.constants import EventTypes
1517 from synapse.api.errors import SynapseError
18 from synapse.http.server import HttpServer
1619 from synapse.http.servlet import (
1720 RestServlet,
1821 assert_params_in_dict,
1922 parse_json_object_from_request,
2023 )
24 from synapse.http.site import SynapseRequest
2125 from synapse.rest.admin import assert_requester_is_admin
2226 from synapse.rest.admin._base import admin_patterns
2327 from synapse.rest.client.transactions import HttpTransactionCache
24 from synapse.types import UserID
28 from synapse.types import JsonDict, UserID
29
30 if TYPE_CHECKING:
31 from synapse.server import HomeServer
2532
2633
2734 class SendServerNoticeServlet(RestServlet):
4350 }
4451 """
4552
46 def __init__(self, hs):
47 """
48 Args:
49 hs (synapse.server.HomeServer): server
50 """
53 def __init__(self, hs: "HomeServer"):
5154 self.hs = hs
5255 self.auth = hs.get_auth()
5356 self.txns = HttpTransactionCache(hs)
5457 self.snm = hs.get_server_notices_manager()
5558
56 def register(self, json_resource):
59 def register(self, json_resource: HttpServer):
5760 PATTERN = "/send_server_notice"
5861 json_resource.register_paths(
5962 "POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
6568 self.__class__.__name__,
6669 )
6770
68 async def on_POST(self, request, txn_id=None):
71 async def on_POST(
72 self, request: SynapseRequest, txn_id: Optional[str] = None
73 ) -> Tuple[int, JsonDict]:
6974 await assert_requester_is_admin(self.auth, request)
7075 body = parse_json_object_from_request(request)
7176 assert_params_in_dict(body, ("user_id", "content"))
8994
9095 return 200, {"event_id": event.event_id}
9196
92 def on_PUT(self, request, txn_id):
97 def on_PUT(self, request: SynapseRequest, txn_id: str) -> Tuple[int, JsonDict]:
9398 return self.txns.fetch_or_execute_request(
9499 request, self.on_POST, request, txn_id
95100 )
268268 target_user.to_string(), False, requester, by_admin=True
269269 )
270270 elif not deactivate and user["deactivated"]:
271 if "password" not in body:
271 if (
272 "password" not in body
273 and self.hs.config.password_localdb_enabled
274 ):
272275 raise SynapseError(
273276 400, "Must provide a password to re-activate an account."
274277 )
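Aside: a sketch of the fixed condition above, as a hypothetical helper. A password is now only demanded for re-activation when local password authentication is actually enabled (the config flag is assumed to be a boolean):

    def password_required_to_reactivate(body: dict, password_localdb_enabled: bool) -> bool:
        # Previously only the body was checked, which broke re-activation on
        # servers that disable local passwords (e.g. SSO-only deployments).
        return "password" not in body and password_localdb_enabled

    assert not password_required_to_reactivate({}, password_localdb_enabled=False)
    assert password_required_to_reactivate({}, password_localdb_enabled=True)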
1313 # limitations under the License.
1414
1515 import logging
16 import re
1617 from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
1718
1819 from synapse.api.errors import Codes, LoginError, SynapseError
1920 from synapse.api.ratelimiting import Ratelimiter
21 from synapse.api.urls import CLIENT_API_PREFIX
2022 from synapse.appservice import ApplicationService
2123 from synapse.handlers.sso import SsoIdentityProvider
2224 from synapse.http import get_request_uri
9395 flows.append({"type": LoginRestServlet.CAS_TYPE})
9496
9597 if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
96 sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict
98 sso_flow = {
99 "type": LoginRestServlet.SSO_TYPE,
100 "identity_providers": [
101 _get_auth_flow_dict_for_idp(
102 idp,
103 )
104 for idp in self._sso_handler.get_identity_providers().values()
105 ],
106 } # type: JsonDict
97107
98108 if self._msc2858_enabled:
109 # backwards-compatibility support for clients which don't
110 # support the stable API yet
99111 sso_flow["org.matrix.msc2858.identity_providers"] = [
100 _get_auth_flow_dict_for_idp(idp)
112 _get_auth_flow_dict_for_idp(idp, use_unstable_brands=True)
101113 for idp in self._sso_handler.get_identity_providers().values()
102114 ]
103115
218230 callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
219231 create_non_existent_users: bool = False,
220232 ratelimit: bool = True,
233 auth_provider_id: Optional[str] = None,
221234 ) -> Dict[str, str]:
222235 """Called when we've successfully authed the user and now need to
223236 actually login them in (e.g. create devices). This gets called on
233246 create_non_existent_users: Whether to create the user if they don't
234247 exist. Defaults to False.
235248 ratelimit: Whether to ratelimit the login request.
249 auth_provider_id: The SSO IdP the user used, if any (just used for the
250 prometheus metrics).
236251
237252 Returns:
238253 result: Dictionary of account information after successful login.
255270 device_id = login_submission.get("device_id")
256271 initial_display_name = login_submission.get("initial_device_display_name")
257272 device_id, access_token = await self.registration_handler.register_device(
258 user_id, device_id, initial_display_name
273 user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id
259274 )
260275
261276 result = {
282297 """
283298 token = login_submission["token"]
284299 auth_handler = self.auth_handler
285 user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
286 token
287 )
300 res = await auth_handler.validate_short_term_login_token(token)
288301
289302 return await self._complete_login(
290 user_id, login_submission, self.auth_handler._sso_login_callback
303 res.user_id,
304 login_submission,
305 self.auth_handler._sso_login_callback,
306 auth_provider_id=res.auth_provider_id,
291307 )
292308
293309 async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
326342 return result
327343
328344
329 def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
345 def _get_auth_flow_dict_for_idp(
346 idp: SsoIdentityProvider, use_unstable_brands: bool = False
347 ) -> JsonDict:
330348 """Return an entry for the login flow dict
331349
332350 Returns an entry suitable for inclusion in "identity_providers" in the
333351 response to GET /_matrix/client/r0/login
352
353 Args:
354 idp: the identity provider to describe
355 use_unstable_brands: whether we should use brand identifiers suitable
356 for the unstable API
334357 """
335358 e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
336359 if idp.idp_icon:
337360 e["icon"] = idp.idp_icon
338361 if idp.idp_brand:
339362 e["brand"] = idp.idp_brand
363 # use the stable brand identifier if the unstable identifier isn't defined.
364 if use_unstable_brands and idp.unstable_idp_brand:
365 e["brand"] = idp.unstable_idp_brand
340366 return e
341367
342368
343369 class SsoRedirectServlet(RestServlet):
344 PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
370 PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
371 re.compile(
372 "^"
373 + CLIENT_API_PREFIX
374 + "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
375 )
376 ]
345377
346378 def __init__(self, hs: "HomeServer"):
347379 # make sure that the relevant handlers are instantiated, so that they
359391 def register(self, http_server: HttpServer) -> None:
360392 super().register(http_server)
361393 if self._msc2858_enabled:
362 # expose additional endpoint for MSC2858 support
394 # expose additional endpoint for MSC2858 support: backwards-compat support
395 # for clients which don't yet support the stable endpoints.
363396 http_server.register_paths(
364397 "GET",
365398 client_patterns(
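Aside: a sketch of the login-flows entry the code above produces for a hypothetical OIDC IdP (the field values are illustrative; "m.login.sso" is the servlet's SSO_TYPE and the unstable key appears in the hunk itself):

    sso_flow = {
        "type": "m.login.sso",
        "identity_providers": [
            {"id": "oidc-github", "name": "GitHub", "brand": "github"},
        ],
        # duplicated under the unstable key for pre-MSC2858 clients
        "org.matrix.msc2858.identity_providers": [
            {"id": "oidc-github", "name": "GitHub", "brand": "github"},
        ],
    }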
670670 results["events_after"], time_now
671671 )
672672 results["state"] = await self._event_serializer.serialize_events(
673 results["state"], time_now
673 results["state"],
674 time_now,
675 # No need to bundle aggregations for state events
676 bundle_aggregations=False,
674677 )
675678
676679 return 200, results
3131 assert_params_in_dict,
3232 parse_json_object_from_request,
3333 )
34 from synapse.http.site import SynapseRequest
3435 from synapse.types import GroupID, JsonDict
3536
3637 from ._base import client_patterns
6970 self.groups_handler = hs.get_groups_local_handler()
7071
7172 @_validate_group_id
72 async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
73 async def on_GET(
74 self, request: SynapseRequest, group_id: str
75 ) -> Tuple[int, JsonDict]:
7376 requester = await self.auth.get_user_by_req(request, allow_guest=True)
7477 requester_user_id = requester.user.to_string()
7578
8083 return 200, group_description
8184
8285 @_validate_group_id
83 async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
86 async def on_POST(
87 self, request: SynapseRequest, group_id: str
88 ) -> Tuple[int, JsonDict]:
8489 requester = await self.auth.get_user_by_req(request)
8590 requester_user_id = requester.user.to_string()
8691
110115 self.groups_handler = hs.get_groups_local_handler()
111116
112117 @_validate_group_id
113 async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
118 async def on_GET(
119 self, request: SynapseRequest, group_id: str
120 ) -> Tuple[int, JsonDict]:
114121 requester = await self.auth.get_user_by_req(request, allow_guest=True)
115122 requester_user_id = requester.user.to_string()
116123
143150
144151 @_validate_group_id
145152 async def on_PUT(
146 self, request: Request, group_id: str, category_id: Optional[str], room_id: str
153 self,
154 request: SynapseRequest,
155 group_id: str,
156 category_id: Optional[str],
157 room_id: str,
147158 ):
148159 requester = await self.auth.get_user_by_req(request)
149160 requester_user_id = requester.user.to_string()
175186
176187 @_validate_group_id
177188 async def on_DELETE(
178 self, request: Request, group_id: str, category_id: str, room_id: str
189 self, request: SynapseRequest, group_id: str, category_id: str, room_id: str
179190 ):
180191 requester = await self.auth.get_user_by_req(request)
181192 requester_user_id = requester.user.to_string()
205216
206217 @_validate_group_id
207218 async def on_GET(
208 self, request: Request, group_id: str, category_id: str
219 self, request: SynapseRequest, group_id: str, category_id: str
209220 ) -> Tuple[int, JsonDict]:
210221 requester = await self.auth.get_user_by_req(request, allow_guest=True)
211222 requester_user_id = requester.user.to_string()
218229
219230 @_validate_group_id
220231 async def on_PUT(
221 self, request: Request, group_id: str, category_id: str
232 self, request: SynapseRequest, group_id: str, category_id: str
222233 ) -> Tuple[int, JsonDict]:
223234 requester = await self.auth.get_user_by_req(request)
224235 requester_user_id = requester.user.to_string()
246257
247258 @_validate_group_id
248259 async def on_DELETE(
249 self, request: Request, group_id: str, category_id: str
260 self, request: SynapseRequest, group_id: str, category_id: str
250261 ) -> Tuple[int, JsonDict]:
251262 requester = await self.auth.get_user_by_req(request)
252263 requester_user_id = requester.user.to_string()
273284 self.groups_handler = hs.get_groups_local_handler()
274285
275286 @_validate_group_id
276 async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
287 async def on_GET(
288 self, request: SynapseRequest, group_id: str
289 ) -> Tuple[int, JsonDict]:
277290 requester = await self.auth.get_user_by_req(request, allow_guest=True)
278291 requester_user_id = requester.user.to_string()
279292
297310
298311 @_validate_group_id
299312 async def on_GET(
300 self, request: Request, group_id: str, role_id: str
313 self, request: SynapseRequest, group_id: str, role_id: str
301314 ) -> Tuple[int, JsonDict]:
302315 requester = await self.auth.get_user_by_req(request, allow_guest=True)
303316 requester_user_id = requester.user.to_string()
310323
311324 @_validate_group_id
312325 async def on_PUT(
313 self, request: Request, group_id: str, role_id: str
326 self, request: SynapseRequest, group_id: str, role_id: str
314327 ) -> Tuple[int, JsonDict]:
315328 requester = await self.auth.get_user_by_req(request)
316329 requester_user_id = requester.user.to_string()
338351
339352 @_validate_group_id
340353 async def on_DELETE(
341 self, request: Request, group_id: str, role_id: str
354 self, request: SynapseRequest, group_id: str, role_id: str
342355 ) -> Tuple[int, JsonDict]:
343356 requester = await self.auth.get_user_by_req(request)
344357 requester_user_id = requester.user.to_string()
365378 self.groups_handler = hs.get_groups_local_handler()
366379
367380 @_validate_group_id
368 async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
381 async def on_GET(
382 self, request: SynapseRequest, group_id: str
383 ) -> Tuple[int, JsonDict]:
369384 requester = await self.auth.get_user_by_req(request, allow_guest=True)
370385 requester_user_id = requester.user.to_string()
371386
398413
399414 @_validate_group_id
400415 async def on_PUT(
401 self, request: Request, group_id: str, role_id: Optional[str], user_id: str
416 self,
417 request: SynapseRequest,
418 group_id: str,
419 role_id: Optional[str],
420 user_id: str,
402421 ) -> Tuple[int, JsonDict]:
403422 requester = await self.auth.get_user_by_req(request)
404423 requester_user_id = requester.user.to_string()
430449
431450 @_validate_group_id
432451 async def on_DELETE(
433 self, request: Request, group_id: str, role_id: str, user_id: str
452 self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
434453 ):
435454 requester = await self.auth.get_user_by_req(request)
436455 requester_user_id = requester.user.to_string()
457476 self.groups_handler = hs.get_groups_local_handler()
458477
459478 @_validate_group_id
460 async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
479 async def on_GET(
480 self, request: SynapseRequest, group_id: str
481 ) -> Tuple[int, JsonDict]:
461482 requester = await self.auth.get_user_by_req(request, allow_guest=True)
462483 requester_user_id = requester.user.to_string()
463484
480501 self.groups_handler = hs.get_groups_local_handler()
481502
482503 @_validate_group_id
483 async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
504 async def on_GET(
505 self, request: SynapseRequest, group_id: str
506 ) -> Tuple[int, JsonDict]:
484507 requester = await self.auth.get_user_by_req(request, allow_guest=True)
485508 requester_user_id = requester.user.to_string()
486509
503526 self.groups_handler = hs.get_groups_local_handler()
504527
505528 @_validate_group_id
506 async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
529 async def on_GET(
530 self, request: SynapseRequest, group_id: str
531 ) -> Tuple[int, JsonDict]:
507532 requester = await self.auth.get_user_by_req(request)
508533 requester_user_id = requester.user.to_string()
509534
525550 self.groups_handler = hs.get_groups_local_handler()
526551
527552 @_validate_group_id
528 async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
553 async def on_PUT(
554 self, request: SynapseRequest, group_id: str
555 ) -> Tuple[int, JsonDict]:
529556 requester = await self.auth.get_user_by_req(request)
530557 requester_user_id = requester.user.to_string()
531558
553580 self.groups_handler = hs.get_groups_local_handler()
554581 self.server_name = hs.hostname
555582
556 async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
583 async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
557584 requester = await self.auth.get_user_by_req(request)
558585 requester_user_id = requester.user.to_string()
559586
597624
598625 @_validate_group_id
599626 async def on_PUT(
600 self, request: Request, group_id: str, room_id: str
627 self, request: SynapseRequest, group_id: str, room_id: str
601628 ) -> Tuple[int, JsonDict]:
602629 requester = await self.auth.get_user_by_req(request)
603630 requester_user_id = requester.user.to_string()
614641
615642 @_validate_group_id
616643 async def on_DELETE(
617 self, request: Request, group_id: str, room_id: str
644 self, request: SynapseRequest, group_id: str, room_id: str
618645 ) -> Tuple[int, JsonDict]:
619646 requester = await self.auth.get_user_by_req(request)
620647 requester_user_id = requester.user.to_string()
645672
646673 @_validate_group_id
647674 async def on_PUT(
648 self, request: Request, group_id: str, room_id: str, config_key: str
675 self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
649676 ):
650677 requester = await self.auth.get_user_by_req(request)
651678 requester_user_id = requester.user.to_string()
677704 self.is_mine_id = hs.is_mine_id
678705
679706 @_validate_group_id
680 async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
707 async def on_PUT(
708 self, request: SynapseRequest, group_id, user_id
709 ) -> Tuple[int, JsonDict]:
681710 requester = await self.auth.get_user_by_req(request)
682711 requester_user_id = requester.user.to_string()
683712
707736 self.groups_handler = hs.get_groups_local_handler()
708737
709738 @_validate_group_id
710 async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
739 async def on_PUT(
740 self, request: SynapseRequest, group_id, user_id
741 ) -> Tuple[int, JsonDict]:
711742 requester = await self.auth.get_user_by_req(request)
712743 requester_user_id = requester.user.to_string()
713744
734765 self.groups_handler = hs.get_groups_local_handler()
735766
736767 @_validate_group_id
737 async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
768 async def on_PUT(
769 self, request: SynapseRequest, group_id: str
770 ) -> Tuple[int, JsonDict]:
738771 requester = await self.auth.get_user_by_req(request)
739772 requester_user_id = requester.user.to_string()
740773
761794 self.groups_handler = hs.get_groups_local_handler()
762795
763796 @_validate_group_id
764 async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
797 async def on_PUT(
798 self, request: SynapseRequest, group_id: str
799 ) -> Tuple[int, JsonDict]:
765800 requester = await self.auth.get_user_by_req(request)
766801 requester_user_id = requester.user.to_string()
767802
788823 self.groups_handler = hs.get_groups_local_handler()
789824
790825 @_validate_group_id
791 async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
826 async def on_PUT(
827 self, request: SynapseRequest, group_id: str
828 ) -> Tuple[int, JsonDict]:
792829 requester = await self.auth.get_user_by_req(request)
793830 requester_user_id = requester.user.to_string()
794831
815852 self.store = hs.get_datastore()
816853
817854 @_validate_group_id
818 async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
855 async def on_PUT(
856 self, request: SynapseRequest, group_id: str
857 ) -> Tuple[int, JsonDict]:
819858 requester = await self.auth.get_user_by_req(request)
820859 requester_user_id = requester.user.to_string()
821860
838877 self.store = hs.get_datastore()
839878 self.groups_handler = hs.get_groups_local_handler()
840879
841 async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
880 async def on_GET(
881 self, request: SynapseRequest, user_id: str
882 ) -> Tuple[int, JsonDict]:
842883 await self.auth.get_user_by_req(request, allow_guest=True)
843884
844885 result = await self.groups_handler.get_publicised_groups_for_user(user_id)
858899 self.store = hs.get_datastore()
859900 self.groups_handler = hs.get_groups_local_handler()
860901
861 async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
902 async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
862903 await self.auth.get_user_by_req(request, allow_guest=True)
863904
864905 content = parse_json_object_from_request(request)
880921 self.clock = hs.get_clock()
881922 self.groups_handler = hs.get_groups_local_handler()
882923
883 async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
924 async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
884925 requester = await self.auth.get_user_by_req(request, allow_guest=True)
885926 requester_user_id = requester.user.to_string()
886927
1919 from twisted.web.server import Request
2020
2121 from synapse.http.server import DirectServeJsonResource, respond_with_json
22 from synapse.http.site import SynapseRequest
2223
2324 if TYPE_CHECKING:
2425 from synapse.app.homeserver import HomeServer
3435 self.auth = hs.get_auth()
3536 self.limits_dict = {"m.upload.size": config.max_upload_size}
3637
37 async def _async_render_GET(self, request: Request) -> None:
38 async def _async_render_GET(self, request: SynapseRequest) -> None:
3839 await self.auth.get_user_by_req(request)
3940 respond_with_json(request, 200, self.limits_dict, send_cors=True)
4041
3434 from synapse.config._base import ConfigError
3535 from synapse.logging.context import defer_to_thread
3636 from synapse.metrics.background_process_metrics import run_as_background_process
37 from synapse.types import UserID
3738 from synapse.util.async_helpers import Linearizer
3839 from synapse.util.retryutils import NotRetryingDestination
3940 from synapse.util.stringutils import random_string
144145 upload_name: Optional[str],
145146 content: IO,
146147 content_length: int,
147 auth_user: str,
148 auth_user: UserID,
148149 ) -> str:
149150 """Store uploaded content for a local user and return the mxc URL
150151
3838 respond_with_json_bytes,
3939 )
4040 from synapse.http.servlet import parse_integer, parse_string
41 from synapse.http.site import SynapseRequest
4142 from synapse.logging.context import make_deferred_yieldable, run_in_background
4243 from synapse.metrics.background_process_metrics import run_as_background_process
4344 from synapse.rest.media.v1._base import get_filename_from_headers
184185 request.setHeader(b"Allow", b"OPTIONS, GET")
185186 respond_with_json(request, 200, {}, send_cors=True)
186187
187 async def _async_render_GET(self, request: Request) -> None:
188 async def _async_render_GET(self, request: SynapseRequest) -> None:
188189
189190 # XXX: if get_user_by_req fails, what should we do in an async render?
190191 requester = await self.auth.get_user_by_req(request)
9595 def _resize(self, width: int, height: int) -> Image:
9696 # 1-bit or 8-bit color palette images need converting to RGB
9797 # otherwise they will be scaled using nearest neighbour which
98 # looks awful
99 if self.image.mode in ["1", "P"]:
100 self.image = self.image.convert("RGB")
98 # looks awful.
99 #
100 # If the image has transparency, use RGBA instead.
101 if self.image.mode in ["1", "L", "P"]:
102 mode = "RGB"
103 if self.image.info.get("transparency", None) is not None:
104 mode = "RGBA"
105 self.image = self.image.convert(mode)
101106 return self.image.resize((width, height), Image.ANTIALIAS)
102107
103108 def scale(self, width: int, height: int, output_type: str) -> BytesIO:
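Aside: a standalone sketch of the conversion rule above with Pillow (the image is synthetic). Palette and greyscale images are converted to RGB, or to RGBA when transparency info is present; losing that info is what previously crashed thumbnailing:

    from PIL import Image

    image = Image.new("P", (4, 4))
    image.info["transparency"] = 0  # simulate a transparent palette entry

    if image.mode in ("1", "L", "P"):
        mode = "RGBA" if image.info.get("transparency", None) is not None else "RGB"
        image = image.convert(mode)

    print(image.mode)  # RGBA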
2121 from synapse.api.errors import Codes, SynapseError
2222 from synapse.http.server import DirectServeJsonResource, respond_with_json
2323 from synapse.http.servlet import parse_string
24 from synapse.http.site import SynapseRequest
2425 from synapse.rest.media.v1.media_storage import SpamMediaException
2526
2627 if TYPE_CHECKING:
4849 async def _async_render_OPTIONS(self, request: Request) -> None:
4950 respond_with_json(request, 200, {}, send_cors=True)
5051
51 async def _async_render_POST(self, request: Request) -> None:
52 async def _async_render_POST(self, request: SynapseRequest) -> None:
5253 requester = await self.auth.get_user_by_req(request)
5354 # TODO: The checks here are a bit late. The content will have
5455 # already been uploaded to a tmp file at this point
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
1515
16 from typing import TYPE_CHECKING
17
1618 from synapse.http.server import DirectServeHtmlResource
19
20 if TYPE_CHECKING:
21 from synapse.server import HomeServer
1722
1823
1924 class SAML2ResponseResource(DirectServeHtmlResource):
2126
2227 isLeaf = 1
2328
24 def __init__(self, hs):
29 def __init__(self, hs: "HomeServer"):
2530 super().__init__()
2631 self._saml_handler = hs.get_saml_handler()
32 self._sso_handler = hs.get_sso_handler()
2733
2834 async def _async_render_GET(self, request):
2935 # We're not expecting any GET request on that resource if everything goes right,
3036 # but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
3137 # In this case, just tell the user that something went wrong and they should
3238 # try to authenticate again.
33 self._saml_handler._render_error(
39 self._sso_handler.render_error(
3440 request, "unexpected_get", "Unexpected GET request on /saml2/authn_response"
3541 )
3642
3535 cast,
3636 )
3737
38 import twisted.internet.base
3938 import twisted.internet.tcp
4039 from twisted.internet import defer
4140 from twisted.mail.smtp import sendmail
129128 from synapse.state import StateHandler, StateResolutionHandler
130129 from synapse.storage import Databases, DataStore, Storage
131130 from synapse.streams.events import EventSources
132 from synapse.types import DomainSpecificString
131 from synapse.types import DomainSpecificString, ISynapseReactor
133132 from synapse.util import Clock
134133 from synapse.util.distributor import Distributor
135134 from synapse.util.ratelimitutils import FederationRateLimiter
290289 for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP:
291290 getattr(self, "get_" + i + "_handler")()
292291
293 def get_reactor(self) -> twisted.internet.base.ReactorBase:
292 def get_reactor(self) -> ISynapseReactor:
294293 """
295294 Fetch the Twisted reactor in use by this HomeServer.
296295 """
351350
352351 @cache_in_self
353352 def get_http_client_context_factory(self) -> IPolicyForHTTPS:
354 return (
355 InsecureInterceptableContextFactory()
356 if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
357 else RegularPolicyForHTTPS()
358 )
353 if self.config.use_insecure_ssl_client_just_for_testing_do_not_use:
354 return InsecureInterceptableContextFactory()
355 return RegularPolicyForHTTPS()
359356
360357 @cache_in_self
361358 def get_simple_http_client(self) -> SimpleHttpClient:
5353 ) # type: LruCache[str, List[Tuple[str, int]]]
5454
5555 async def get_auth_chain(
56 self, event_ids: Collection[str], include_given: bool = False
56 self, room_id: str, event_ids: Collection[str], include_given: bool = False
5757 ) -> List[EventBase]:
5858 """Get auth events for given event_ids. The events *must* be state events.
5959
6060 Args:
61 room_id: The room the event is in.
6162 event_ids: state events
6263 include_given: include the given events in result
6364
6566 list of events
6667 """
6768 event_ids = await self.get_auth_chain_ids(
68 event_ids, include_given=include_given
69 room_id, event_ids, include_given=include_given
6970 )
7071 return await self.get_events_as_list(event_ids)
7172
7273 async def get_auth_chain_ids(
7374 self,
75 room_id: str,
7476 event_ids: Collection[str],
7577 include_given: bool = False,
7678 ) -> List[str]:
7779 """Get auth events for given event_ids. The events *must* be state events.
7880
7981 Args:
82 room_id: The room the event is in.
8083 event_ids: state events
8184 include_given: include the given events in result
8285
8386 Returns:
84 An awaitable which resolve to a list of event_ids
85 """
87 list of event_ids
88 """
89
90 # Check if we have indexed the room so we can use the chain cover
91 # algorithm.
92 room = await self.get_room(room_id)
93 if room["has_auth_chain_index"]:
94 try:
95 return await self.db_pool.runInteraction(
96 "get_auth_chain_ids_chains",
97 self._get_auth_chain_ids_using_cover_index_txn,
98 room_id,
99 event_ids,
100 include_given,
101 )
102 except _NoChainCoverIndex:
103 # For whatever reason we don't actually have a chain cover index
104 # for the events in question, so we fall back to the old method.
105 pass
106
86107 return await self.db_pool.runInteraction(
87108 "get_auth_chain_ids",
88109 self._get_auth_chain_ids_txn,
90111 include_given,
91112 )
92113
114 def _get_auth_chain_ids_using_cover_index_txn(
115 self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
116 ) -> List[str]:
117 """Calculates the auth chain IDs using the chain index."""
118
119 # First we look up the chain ID/sequence numbers for the given events.
120
121 initial_events = set(event_ids)
122
123 # All the events that we've found that are reachable from the events.
124 seen_events = set() # type: Set[str]
125
126 # A map from chain ID to max sequence number of the given events.
127 event_chains = {} # type: Dict[int, int]
128
129 sql = """
130 SELECT event_id, chain_id, sequence_number
131 FROM event_auth_chains
132 WHERE %s
133 """
134 for batch in batch_iter(initial_events, 1000):
135 clause, args = make_in_list_sql_clause(
136 txn.database_engine, "event_id", batch
137 )
138 txn.execute(sql % (clause,), args)
139
140 for event_id, chain_id, sequence_number in txn:
141 seen_events.add(event_id)
142 event_chains[chain_id] = max(
143 sequence_number, event_chains.get(chain_id, 0)
144 )
145
146 # Check that we actually have a chain ID for all the events.
147 events_missing_chain_info = initial_events.difference(seen_events)
148 if events_missing_chain_info:
149 # This can happen due to e.g. downgrade/upgrade of the server. We
150 # raise an exception and fall back to the previous algorithm.
151 logger.info(
152 "Unexpectedly found that events don't have chain IDs in room %s: %s",
153 room_id,
154 events_missing_chain_info,
155 )
156 raise _NoChainCoverIndex(room_id)
157
158 # Now we look up all links for the chains we have, adding chains that
159 # are reachable from any event.
160 sql = """
161 SELECT
162 origin_chain_id, origin_sequence_number,
163 target_chain_id, target_sequence_number
164 FROM event_auth_chain_links
165 WHERE %s
166 """
167
168 # A map from chain ID to max sequence number *reachable* from any event ID.
169 chains = {} # type: Dict[int, int]
170
171 # Add all linked chains reachable from initial set of chains.
172 for batch in batch_iter(event_chains, 1000):
173 clause, args = make_in_list_sql_clause(
174 txn.database_engine, "origin_chain_id", batch
175 )
176 txn.execute(sql % (clause,), args)
177
178 for (
179 origin_chain_id,
180 origin_sequence_number,
181 target_chain_id,
182 target_sequence_number,
183 ) in txn:
184 # chains are only reachable if the origin sequence number of
185 # the link is less than the max sequence number in the
186 # origin chain.
187 if origin_sequence_number <= event_chains.get(origin_chain_id, 0):
188 chains[target_chain_id] = max(
189 target_sequence_number,
190 chains.get(target_chain_id, 0),
191 )
192
193 # Add the initial set of chains, excluding the sequence corresponding to
194 # the initial event.
195 for chain_id, seq_no in event_chains.items():
196 chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0))
197
198 # Now for each chain we figure out the maximum sequence number reachable
199 # from *any* event ID. Events with a sequence less than that are in the
200 # auth chain.
201 if include_given:
202 results = initial_events
203 else:
204 results = set()
205
206 if isinstance(self.database_engine, PostgresEngine):
207 # We can use `execute_values` to efficiently fetch the gaps when
208 # using postgres.
209 sql = """
210 SELECT event_id
211 FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
212 WHERE
213 c.chain_id = l.chain_id
214 AND sequence_number <= max_seq
215 """
216
217 rows = txn.execute_values(sql, chains.items())
218 results.update(r for r, in rows)
219 else:
220 # For SQLite we just fall back to doing a noddy for loop.
221 sql = """
222 SELECT event_id FROM event_auth_chains
223 WHERE chain_id = ? AND sequence_number <= ?
224 """
225 for chain_id, max_no in chains.items():
226 txn.execute(sql, (chain_id, max_no))
227 results.update(r for r, in txn)
228
229 return list(results)
230
93231 def _get_auth_chain_ids_txn(
94232 self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
95233 ) -> List[str]:
234 """Calculates the auth chain IDs.
235
236 This is used when we don't have a cover index for the room.
237 """
96238 if include_given:
97239 results = set(event_ids)
98240 else:
132132 self.db_pool.updates.register_background_update_handler(
133133 "chain_cover",
134134 self._chain_cover_index,
135 )
136
137 self.db_pool.updates.register_background_update_handler(
138 "purged_chain_cover",
139 self._purged_chain_cover_index,
135140 )
136141
137142 async def _background_reindex_fields_sender(self, progress, batch_size):
931936 processed_count=count,
932937 finished_room_map=finished_rooms,
933938 )
939
940 async def _purged_chain_cover_index(self, progress: dict, batch_size: int) -> int:
941 """
942 A background update that iterates over the chain cover and deletes the
943 chain cover for events that have been purged.
944
945 This may be due to fully purging a room or to a retention policy.
946 """
947 current_event_id = progress.get("current_event_id", "")
948
949 def purged_chain_cover_txn(txn) -> int:
950 # The event ID from events will be null if the chain ID / sequence
951 # number points to a purged event.
952 sql = """
953 SELECT event_id, chain_id, sequence_number, e.event_id IS NOT NULL
954 FROM event_auth_chains
955 LEFT JOIN events AS e USING (event_id)
956 WHERE event_id > ? ORDER BY event_auth_chains.event_id ASC LIMIT ?
957 """
958 txn.execute(sql, (current_event_id, batch_size))
959
960 rows = txn.fetchall()
961 if not rows:
962 return 0
963
964 # The event IDs and chain IDs / sequence numbers where the event has
965 # been purged.
966 unreferenced_event_ids = []
967 unreferenced_chain_id_tuples = []
968 event_id = ""
969 for event_id, chain_id, sequence_number, has_event in rows:
970 if not has_event:
971 unreferenced_event_ids.append((event_id,))
972 unreferenced_chain_id_tuples.append((chain_id, sequence_number))
973
974 # Delete the unreferenced auth chains from event_auth_chain_links and
975 # event_auth_chains.
976 txn.executemany(
977 """
978 DELETE FROM event_auth_chains WHERE event_id = ?
979 """,
980 unreferenced_event_ids,
981 )
982 # We should also delete matching target_*, but there is no index on
983 # target_chain_id. Hopefully any purged events are due to a room
984 # being fully purged and they will be removed from the origin_*
985 # searches.
986 txn.executemany(
987 """
988 DELETE FROM event_auth_chain_links WHERE
989 origin_chain_id = ? AND origin_sequence_number = ?
990 """,
991 unreferenced_chain_id_tuples,
992 )
993
994 progress = {
995 "current_event_id": event_id,
996 }
997
998 self.db_pool.updates._background_update_progress_txn(
999 txn, "purged_chain_cover", progress
1000 )
1001
1002 return len(rows)
1003
1004 result = await self.db_pool.runInteraction(
1005 "_purged_chain_cover_index",
1006 purged_chain_cover_txn,
1007 )
1008
1009 if not result:
1010 await self.db_pool.updates._end_background_update("purged_chain_cover")
1011
1012 return result
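The method follows the usual paginated background-update shape: process one batch per call, persist a resume token, and report the number of rows handled, with zero meaning done. A rough sketch of the driving loop (names assumed; this is not Synapse's updater API):

    def run_background_update(handler, batch_size=100):
        # re-invoke the handler until it reports an empty batch, then stop --
        # the equivalent of _end_background_update above
        while True:
            processed = handler(batch_size)
            if processed == 0:
                break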
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 import itertools
14
1515 import logging
1616 import threading
1717 from collections import namedtuple
10431043 Returns:
10441044 set[str]: The events we have already seen.
10451045 """
1046 results = set()
1046 # if the event cache contains the event, obviously we've seen it.
1047 results = {x for x in event_ids if self._get_event_cache.contains(x)}
10471048
10481049 def have_seen_events_txn(txn, chunk):
10491050 sql = "SELECT event_id FROM events as e WHERE "
10501051 clause, args = make_in_list_sql_clause(
10511052 txn.database_engine, "e.event_id", chunk
10521053 )
10531054 txn.execute(sql + clause, args)
1054 for (event_id,) in txn:
1055 results.add(event_id)
1056
1057 # break the input up into chunks of 100
1058 input_iterator = iter(event_ids)
1059 for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
1055 results.update(row[0] for row in txn)
1056
1057 for chunk in batch_iter((x for x in event_ids if x not in results), 100):
10601058 await self.db_pool.runInteraction(
10611059 "have_seen_events", have_seen_events_txn, chunk
10621060 )
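`batch_iter` (from `synapse.util.iterutils`) chunks an iterable into fixed-size tuples, replacing the hand-rolled `itertools.islice` loop removed above. A minimal equivalent, assuming the real helper behaves the same way:

    from itertools import islice

    def batch_iter(iterable, size):
        # yield successive tuples of at most `size` items
        it = iter(iterable)
        return iter(lambda: tuple(islice(it, size)), ())

    assert list(batch_iter(range(5), 2)) == [(0, 1), (2, 3), (4,)]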
330330 txn.executemany(
331331 """
332332 DELETE FROM event_auth_chain_links WHERE
333 (origin_chain_id = ? AND origin_sequence_number = ?) OR
334 (target_chain_id = ? AND target_sequence_number = ?)
333 origin_chain_id = ? AND origin_sequence_number = ?
335334 """,
336 (
337 (chain_id, seq_num, chain_id, seq_num)
338 for (chain_id, seq_num) in referenced_chain_id_tuples
339 ),
335 referenced_chain_id_tuples,
340336 )
341337
342338 # Now we delete tables which lack an index on room_id but have one on event_id
1515 # limitations under the License.
1616 import logging
1717 import re
18 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
18 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
1919
2020 import attr
2121
15091509 async def user_delete_access_tokens(
15101510 self,
15111511 user_id: str,
1512 except_token_id: Optional[str] = None,
1512 except_token_id: Optional[int] = None,
15131513 device_id: Optional[str] = None,
15141514 ) -> List[Tuple[str, int, Optional[str]]]:
15151515 """
15321532
15331533 items = keyvalues.items()
15341534 where_clause = " AND ".join(k + " = ?" for k, _ in items)
1535 values = [v for _, v in items]
1535 values = [v for _, v in items] # type: List[Union[str, int]]
15361536 if except_token_id:
15371537 where_clause += " AND id != ?"
15381538 values.append(except_token_id)
0 /* Copyright 2021 The Matrix.org Foundation C.I.C
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
16 (5910, 'purged_chain_cover', '{}');
349349
350350 self.db_pool.simple_upsert_many_txn(
351351 txn,
352 "destination_rooms",
353 ["destination", "room_id"],
354 rows,
355 ["stream_ordering"],
356 [(stream_ordering,)] * len(rows),
352 table="destination_rooms",
353 key_names=("destination", "room_id"),
354 key_values=rows,
355 value_names=["stream_ordering"],
356 value_values=[(stream_ordering,)] * len(rows),
357357 )
358358
359359 async def get_destination_last_successful_stream_ordering(
3434 import attr
3535 from signedjson.key import decode_verify_key_bytes
3636 from unpaddedbase64 import decode_base64
37 from zope.interface import Interface
38
39 from twisted.internet.interfaces import (
40 IReactorCore,
41 IReactorPluggableNameResolver,
42 IReactorTCP,
43 IReactorTime,
44 )
3745
3846 from synapse.api.errors import Codes, SynapseError
3947 from synapse.util.stringutils import parse_and_validate_server_name
6674 JsonDict = Dict[str, Any]
6775
6876
69 class Requester(
70 namedtuple(
71 "Requester",
72 [
73 "user",
74 "access_token_id",
75 "is_guest",
76 "shadow_banned",
77 "device_id",
78 "app_service",
79 "authenticated_entity",
80 ],
81 )
77 # Note that this seems to require inheriting *directly* from Interface in order
78 # for mypy-zope to realize it is an interface.
79 class ISynapseReactor(
80 IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface
8281 ):
82 """The interfaces necessary for Synapse to function."""
83
84
85 @attr.s(frozen=True, slots=True)
86 class Requester:
8387 """
8488 Represents the user making a request
8589
8690 Attributes:
87 user (UserID): id of the user making the request
88 access_token_id (int|None): *ID* of the access token used for this
91 user: id of the user making the request
92 access_token_id: *ID* of the access token used for this
8993 request, or None if it came via the appservice API or similar
90 is_guest (bool): True if the user making this request is a guest user
91 shadow_banned (bool): True if the user making this request has been shadow-banned.
92 device_id (str|None): device_id which was set at authentication time
93 app_service (ApplicationService|None): the AS requesting on behalf of the user
94 """
94 is_guest: True if the user making this request is a guest user
95 shadow_banned: True if the user making this request has been shadow-banned.
96 device_id: device_id which was set at authentication time
97 app_service: the AS requesting on behalf of the user
98 authenticated_entity: The entity that authenticated when making the request.
99 This is different to the user_id when an admin user or the server is
100 "puppeting" the user.
101 """
102
103 user = attr.ib(type="UserID")
104 access_token_id = attr.ib(type=Optional[int])
105 is_guest = attr.ib(type=bool)
106 shadow_banned = attr.ib(type=bool)
107 device_id = attr.ib(type=Optional[str])
108 app_service = attr.ib(type=Optional["ApplicationService"])
109 authenticated_entity = attr.ib(type=str)
95110
96111 def serialize(self):
97112 """Converts self to a type that can be serialized as JSON, and then
140155 def create_requester(
141156 user_id: Union[str, "UserID"],
142157 access_token_id: Optional[int] = None,
143 is_guest: Optional[bool] = False,
144 shadow_banned: Optional[bool] = False,
158 is_guest: bool = False,
159 shadow_banned: bool = False,
145160 device_id: Optional[str] = None,
146161 app_service: Optional["ApplicationService"] = None,
147162 authenticated_entity: Optional[str] = None,
148 ):
163 ) -> Requester:
149164 """
150165 Create a new ``Requester`` object
151166
152167 Args:
153 user_id (str|UserID): id of the user making the request
154 access_token_id (int|None): *ID* of the access token used for this
168 user_id: id of the user making the request
169 access_token_id: *ID* of the access token used for this
155170 request, or None if it came via the appservice API or similar
156 is_guest (bool): True if the user making this request is a guest user
157 shadow_banned (bool): True if the user making this request is shadow-banned.
158 device_id (str|None): device_id which was set at authentication time
159 app_service (ApplicationService|None): the AS requesting on behalf of the user
171 is_guest: True if the user making this request is a guest user
172 shadow_banned: True if the user making this request is shadow-banned.
173 device_id: device_id which was set at authentication time
174 app_service: the AS requesting on behalf of the user
160175 authenticated_entity: The entity that authenticated when making the request.
161176 This is different to the user_id when an admin user or the server is
162177 "puppeting" the user.
7575 def callback(r):
7676 object.__setattr__(self, "_result", (True, r))
7777 while self._observers:
78 observer = self._observers.pop()
7879 try:
79 # TODO: Handle errors here.
80 self._observers.pop().callback(r)
81 except Exception:
82 pass
80 observer.callback(r)
81 except Exception as e:
82 logger.exception(
83 "%r threw an exception on .callback(%r), ignoring...",
84 observer,
85 r,
86 exc_info=e,
87 )
8388 return r
8489
8590 def errback(f):
8994 # traces when we `await` on one of the observer deferreds.
9095 f.value.__failure__ = f
9196
97 observer = self._observers.pop()
9298 try:
93 # TODO: Handle errors here.
94 self._observers.pop().errback(f)
95 except Exception:
96 pass
99 observer.errback(f)
100 except Exception as e:
101 logger.exception(
102 "%r threw an exception on .errback(%r), ignoring...",
103 observer,
104 f,
105 exc_info=e,
106 )
97107
98108 if consumeErrors:
99109 return None
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414 import logging
15 from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
15 from typing import Any, Callable, Dict, Generic, Optional, TypeVar
1616
1717 from twisted.internet import defer
1818
1919 from synapse.logging.context import make_deferred_yieldable, run_in_background
20 from synapse.util import Clock
2021 from synapse.util.async_helpers import ObservableDeferred
2122 from synapse.util.caches import register_cache
22
23 if TYPE_CHECKING:
24 from synapse.app.homeserver import HomeServer
2523
2624 logger = logging.getLogger(__name__)
2725
3634 used rather than trying to compute a new response.
3735 """
3836
39 def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
37 def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
4038 # Requests that haven't finished yet.
4139 self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
4240
43 self.clock = hs.get_clock()
41 self.clock = clock
4442 self.timeout_sec = timeout_ms / 1000.0
4543
4644 self._name = name
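After this refactor a call site passes the `Clock` directly rather than the whole `HomeServer`; a hypothetical construction (cache name and timeout are illustrative):

    cache = ResponseCache(hs.get_clock(), "room_list", timeout_ms=30 * 1000)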
0 # -*- coding: utf-8 -*-
1 # Copyright 2020 Quentin Gliech
2 # Copyright 2021 The Matrix.org Foundation C.I.C.
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 # http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15
16 """Utilities for manipulating macaroons"""
17
18 from typing import Callable, Optional
19
20 import pymacaroons
21 from pymacaroons.exceptions import MacaroonVerificationFailedException
22
23
24 def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
25 """Extracts a caveat value from a macaroon token.
26
27 Checks that there is exactly one caveat of the form "key = <val>" in the macaroon,
28 and returns the extracted value.
29
30 Args:
31 macaroon: the token
32 key: the key of the caveat to extract
33
34 Returns:
35 The extracted value
36
37 Raises:
38 MacaroonVerificationFailedException: if there are conflicting values for the
39 caveat in the macaroon, or if the caveat was not found in the macaroon.
40 """
41 prefix = key + " = "
42 result = None # type: Optional[str]
43 for caveat in macaroon.caveats:
44 if not caveat.caveat_id.startswith(prefix):
45 continue
46
47 val = caveat.caveat_id[len(prefix) :]
48
49 if result is None:
50 # first time we found this caveat: record the value
51 result = val
52 elif val != result:
53 # on subsequent occurrences, raise if the value is different.
54 raise MacaroonVerificationFailedException(
55 "Conflicting values for caveat " + key
56 )
57
58 if result is not None:
59 return result
60
61 # If the caveat is not there, we raise a MacaroonVerificationFailedException.
62 # Note that it is insecure to generate a macaroon without all the caveats you
63 # might need (because there is nothing stopping people from adding extra caveats),
64 # so if the caveat isn't there, something odd must be going on.
65 raise MacaroonVerificationFailedException("No %s caveat in macaroon" % (key,))
66
67
68 def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None:
69 """Make a macaroon verifier which accepts 'time' caveats
70
71 Builds a caveat verifier which will accept unexpired 'time' caveats, and adds it to
72 the given macaroon verifier.
73
74 Args:
75 v: the macaroon verifier
76 get_time_ms: a callable which will return the timestamp after which the caveat
77 should be considered expired. Normally the current time.
78 """
79
80 def verify_expiry_caveat(caveat: str):
81 time_msec = get_time_ms()
82 prefix = "time < "
83 if not caveat.startswith(prefix):
84 return False
85 expiry = int(caveat[len(prefix) :])
86 return time_msec < expiry
87
88 v.satisfy_general(verify_expiry_caveat)
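A usage sketch for the two helpers above, with made-up caveat values:

    import pymacaroons

    key = "secret"
    m = pymacaroons.Macaroon(location="synapse", identifier="key1", key=key)
    m.add_first_party_caveat("user_id = @alice:example.com")
    m.add_first_party_caveat("time < 2000")

    assert get_value_from_macaroon(m, "user_id") == "@alice:example.com"

    v = pymacaroons.Verifier()
    v.satisfy_exact("user_id = @alice:example.com")
    satisfy_expiry(v, lambda: 1000)  # "now" (1000ms) is before the 2000ms expiry
    assert v.verify(m, key)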
66 from synapse.federation.units import Edu
77 from synapse.rest import admin
88 from synapse.rest.client.v1 import login, room
9 from synapse.util.retryutils import NotRetryingDestination
910
1011 from tests.test_utils import event_injection, make_awaitable
1112 from tests.unittest import FederatingHomeserverTestCase, override_config
4849 else:
4950 data = json_cb()
5051 self.failed_pdus.extend(data["pdus"])
51 raise IOError("Failed to connect because this is a test!")
52 raise NotRetryingDestination(0, 24 * 60 * 60 * 1000, txn.destination)
5253
5354 def get_destination_room(self, room: str, destination: str = "host2") -> dict:
5455 """
0 -----BEGIN PRIVATE KEY-----
1 MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgrHMvFcFjFhei6gHp
2 Gfy4C8+6z7634MZbC7SSx4a17GahRANCAATp0YxEzGUXuqszggiFxczDdPgDpCJA
3 P18rRuN7FLwZDuzYQPb8zVd8eGh4BqxjiVocICnVWyaSWD96N00I96SW
4 -----END PRIVATE KEY-----
0 -----BEGIN PUBLIC KEY-----
1 MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE6dGMRMxlF7qrM4IIhcXMw3T4A6Qi
2 QD9fK0bjexS8GQ7s2ED2/M1XfHhoeAasY4laHCAp1Vsmklg/ejdNCPeklg==
3 -----END PUBLIC KEY-----
6767 v.verify(macaroon, self.hs.config.macaroon_secret_key)
6868
6969 def test_short_term_login_token_gives_user_id(self):
70 token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
71 user_id = self.get_success(
72 self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
73 )
74 self.assertEqual("a_user", user_id)
70 token = self.macaroon_generator.generate_short_term_login_token(
71 "a_user", "", 5000
72 )
73 res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
74 self.assertEqual("a_user", res.user_id)
75 self.assertEqual("", res.auth_provider_id)
7576
7677 # when we advance the clock, the token should be rejected
7778 self.reactor.advance(6)
7879 self.get_failure(
79 self.auth_handler.validate_short_term_login_token_and_get_user_id(token),
80 self.auth_handler.validate_short_term_login_token(token),
8081 AuthError,
8182 )
8283
84 def test_short_term_login_token_gives_auth_provider(self):
85 token = self.macaroon_generator.generate_short_term_login_token(
86 "a_user", auth_provider_id="my_idp"
87 )
88 res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
89 self.assertEqual("a_user", res.user_id)
90 self.assertEqual("my_idp", res.auth_provider_id)
91
8392 def test_short_term_login_token_cannot_replace_user_id(self):
84 token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
93 token = self.macaroon_generator.generate_short_term_login_token(
94 "a_user", "", 5000
95 )
8596 macaroon = pymacaroons.Macaroon.deserialize(token)
8697
87 user_id = self.get_success(
88 self.auth_handler.validate_short_term_login_token_and_get_user_id(
89 macaroon.serialize()
90 )
91 )
92 self.assertEqual("a_user", user_id)
98 res = self.get_success(
99 self.auth_handler.validate_short_term_login_token(macaroon.serialize())
100 )
101 self.assertEqual("a_user", res.user_id)
93102
94103 # add another "user_id" caveat, which might allow us to override the
95104 # user_id.
96105 macaroon.add_first_party_caveat("user_id = b_user")
97106
98107 self.get_failure(
99 self.auth_handler.validate_short_term_login_token_and_get_user_id(
100 macaroon.serialize()
101 ),
108 self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
102109 AuthError,
103110 )
104111
112119 )
113120
114121 self.get_success(
115 self.auth_handler.validate_short_term_login_token_and_get_user_id(
122 self.auth_handler.validate_short_term_login_token(
116123 self._get_macaroon().serialize()
117124 )
118125 )
134141 return_value=make_awaitable(self.large_number_of_users)
135142 )
136143 self.get_failure(
137 self.auth_handler.validate_short_term_login_token_and_get_user_id(
144 self.auth_handler.validate_short_term_login_token(
138145 self._get_macaroon().serialize()
139146 ),
140147 ResourceLimitError,
158165 ResourceLimitError,
159166 )
160167 self.get_failure(
161 self.auth_handler.validate_short_term_login_token_and_get_user_id(
168 self.auth_handler.validate_short_term_login_token(
162169 self._get_macaroon().serialize()
163170 ),
164171 ResourceLimitError,
174181 )
175182 )
176183 self.get_success(
177 self.auth_handler.validate_short_term_login_token_and_get_user_id(
184 self.auth_handler.validate_short_term_login_token(
178185 self._get_macaroon().serialize()
179186 )
180187 )
196203 return_value=make_awaitable(self.small_number_of_users)
197204 )
198205 self.get_success(
199 self.auth_handler.validate_short_term_login_token_and_get_user_id(
206 self.auth_handler.validate_short_term_login_token(
200207 self._get_macaroon().serialize()
201208 )
202209 )
203210
204211 def _get_macaroon(self):
205 token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000)
212 token = self.macaroon_generator.generate_short_term_login_token(
213 "user_a", "", 5000
214 )
206215 return pymacaroons.Macaroon.deserialize(token)
6565
6666 # check that the auth handler got called as expected
6767 auth_handler.complete_sso_login.assert_called_once_with(
68 "@test_user:test", request, "redirect_uri", None, new_user=True
68 "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
6969 )
7070
7171 def test_map_cas_user_to_existing_user(self):
8888
8989 # check that the auth handler got called as expected
9090 auth_handler.complete_sso_login.assert_called_once_with(
91 "@test_user:test", request, "redirect_uri", None, new_user=False
91 "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
9292 )
9393
9494 # Subsequent calls should map to the same mxid.
9797 self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
9898 )
9999 auth_handler.complete_sso_login.assert_called_once_with(
100 "@test_user:test", request, "redirect_uri", None, new_user=False
100 "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
101101 )
102102
103103 def test_map_cas_user_to_invalid_localpart(self):
115115
116116 # check that the auth handler got called as expected
117117 auth_handler.complete_sso_login.assert_called_once_with(
118 "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
118 "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True
119119 )
120120
121121 @override_config(
159159
160160 # check that the auth handler got called as expected
161161 auth_handler.complete_sso_login.assert_called_once_with(
162 "@test_user:test", request, "redirect_uri", None, new_user=True
162 "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
163163 )
164164
165165
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414 import json
15 from typing import Optional
15 import os
1616 from urllib.parse import parse_qs, urlparse
1717
1818 from mock import ANY, Mock, patch
2222 from synapse.handlers.sso import MappingException
2323 from synapse.server import HomeServer
2424 from synapse.types import UserID
25 from synapse.util.macaroons import get_value_from_macaroon
2526
2627 from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
2728 from tests.unittest import HomeserverTestCase, override_config
4950 JWKS_URI = ISSUER + ".well-known/jwks.json"
5051
5152 # config for common cases
52 COMMON_CONFIG = {
53 DEFAULT_CONFIG = {
54 "enabled": True,
55 "client_id": CLIENT_ID,
56 "client_secret": CLIENT_SECRET,
57 "issuer": ISSUER,
58 "scopes": SCOPES,
59 "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
60 }
61
62 # extends the default config with explicit OAuth2 endpoints instead of using discovery
63 EXPLICIT_ENDPOINT_CONFIG = {
64 **DEFAULT_CONFIG,
5365 "discover": False,
5466 "authorization_endpoint": AUTHORIZATION_ENDPOINT,
5567 "token_endpoint": TOKEN_ENDPOINT,
106118 return {"keys": []}
107119
108120
121 def _key_file_path() -> str:
122 """path to a file containing the private half of a test key"""
123
124 # this key was generated with:
125 # openssl ecparam -name prime256v1 -genkey -noout |
126 # openssl pkcs8 -topk8 -nocrypt -out oidc_test_key.p8
127 #
128 # we use PKCS8 rather than SEC-1 (which is what openssl ecparam spits out), because
129 # that's what Apple use, and we want to be sure that we work with Apple's keys.
130 #
131 # (For the record: both PKCS8 and SEC-1 specify (different) ways of representing
132 # keys using ASN.1. Both are then typically formatted using PEM, which says: use the
133 # base64-encoded DER encoding of ASN.1, with headers and footers. But we don't
134 # really need to care about any of that.)
135 return os.path.join(os.path.dirname(__file__), "oidc_test_key.p8")
136
137
138 def _public_key_file_path() -> str:
139 """path to a file containing the public half of a test key"""
140 # this was generated with:
141 # openssl ec -in oidc_test_key.p8 -pubout -out oidc_test_key.pub.pem
142 #
143 # See above about where oidc_test_key.p8 came from
144 return os.path.join(os.path.dirname(__file__), "oidc_test_key.pub.pem")
145
146
109147 class OidcHandlerTestCase(HomeserverTestCase):
110148 if not HAS_OIDC:
111149 skip = "requires OIDC"
113151 def default_config(self):
114152 config = super().default_config()
115153 config["public_baseurl"] = BASE_URL
116 oidc_config = {
117 "enabled": True,
118 "client_id": CLIENT_ID,
119 "client_secret": CLIENT_SECRET,
120 "issuer": ISSUER,
121 "scopes": SCOPES,
122 "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
123 }
124
125 # Update this config with what's in the default config so that
126 # override_config works as expected.
127 oidc_config.update(config.get("oidc_config", {}))
128 config["oidc_config"] = oidc_config
129
130154 return config
131155
132156 def make_homeserver(self, reactor, clock):
169193 self.render_error.reset_mock()
170194 return args
171195
196 @override_config({"oidc_config": DEFAULT_CONFIG})
172197 def test_config(self):
173198 """Basic config correctly sets up the callback URL and client auth correctly."""
174199 self.assertEqual(self.provider._callback_url, CALLBACK_URL)
175200 self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
176201 self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
177202
178 @override_config({"oidc_config": {"discover": True}})
203 @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
179204 def test_discovery(self):
180205 """The handler should discover the endpoints from OIDC discovery document."""
181206 # This would throw if some metadata were invalid
194219 self.get_success(self.provider.load_metadata())
195220 self.http_client.get_json.assert_not_called()
196221
197 @override_config({"oidc_config": COMMON_CONFIG})
222 @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
198223 def test_no_discovery(self):
199224 """When discovery is disabled, it should not try to load from discovery document."""
200225 self.get_success(self.provider.load_metadata())
201226 self.http_client.get_json.assert_not_called()
202227
203 @override_config({"oidc_config": COMMON_CONFIG})
228 @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
204229 def test_load_jwks(self):
205230 """JWKS loading is done once (then cached) if used."""
206231 jwks = self.get_success(self.provider.load_jwks())
235260 self.http_client.get_json.assert_not_called()
236261 self.assertEqual(jwks, {"keys": []})
237262
263 @override_config({"oidc_config": DEFAULT_CONFIG})
238264 def test_validate_config(self):
239265 """Provider metadatas are extensively validated."""
240266 h = self.provider
317343 # Shouldn't raise with a valid userinfo, even without jwks
318344 force_load_metadata()
319345
320 @override_config({"oidc_config": {"skip_verification": True}})
346 @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
321347 def test_skip_verification(self):
322348 """Provider metadata validation can be disabled by config."""
323349 with self.metadata_edit({"issuer": "http://insecure"}):
324350 # This should not throw
325351 get_awaitable_result(self.provider.load_metadata())
326352
353 @override_config({"oidc_config": DEFAULT_CONFIG})
327354 def test_redirect_request(self):
328355 """The redirect request has the right arguments & generates a valid session cookie."""
329356 req = Mock(spec=["cookies"])
359386 self.assertEqual(name, b"oidc_session")
360387
361388 macaroon = pymacaroons.Macaroon.deserialize(cookie)
362 state = self.handler._token_generator._get_value_from_macaroon(
363 macaroon, "state"
364 )
365 nonce = self.handler._token_generator._get_value_from_macaroon(
366 macaroon, "nonce"
367 )
368 redirect = self.handler._token_generator._get_value_from_macaroon(
369 macaroon, "client_redirect_url"
370 )
389 state = get_value_from_macaroon(macaroon, "state")
390 nonce = get_value_from_macaroon(macaroon, "nonce")
391 redirect = get_value_from_macaroon(macaroon, "client_redirect_url")
371392
372393 self.assertEqual(params["state"], [state])
373394 self.assertEqual(params["nonce"], [nonce])
374395 self.assertEqual(redirect, "http://client/redirect")
375396
397 @override_config({"oidc_config": DEFAULT_CONFIG})
376398 def test_callback_error(self):
377399 """Errors from the provider returned in the callback are displayed."""
378400 request = Mock(args={})
384406 self.get_success(self.handler.handle_oidc_callback(request))
385407 self.assertRenderedError("invalid_client", "some description")
386408
409 @override_config({"oidc_config": DEFAULT_CONFIG})
387410 def test_callback(self):
388411 """Code callback works and display errors if something went wrong.
389412
433456 self.get_success(self.handler.handle_oidc_callback(request))
434457
435458 auth_handler.complete_sso_login.assert_called_once_with(
436 expected_user_id, request, client_redirect_url, None, new_user=True
459 expected_user_id, "oidc", request, client_redirect_url, None, new_user=True
437460 )
438461 self.provider._exchange_code.assert_called_once_with(code)
439462 self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
464487 self.get_success(self.handler.handle_oidc_callback(request))
465488
466489 auth_handler.complete_sso_login.assert_called_once_with(
467 expected_user_id, request, client_redirect_url, None, new_user=False
490 expected_user_id, "oidc", request, client_redirect_url, None, new_user=False
468491 )
469492 self.provider._exchange_code.assert_called_once_with(code)
470493 self.provider._parse_id_token.assert_not_called()
485508 self.get_success(self.handler.handle_oidc_callback(request))
486509 self.assertRenderedError("invalid_request")
487510
511 @override_config({"oidc_config": DEFAULT_CONFIG})
488512 def test_callback_session(self):
489513 """The callback verifies the session presence and validity"""
490514 request = Mock(spec=["args", "getCookie", "cookies"])
527551 self.get_success(self.handler.handle_oidc_callback(request))
528552 self.assertRenderedError("invalid_request")
529553
530 @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
554 @override_config(
555 {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
556 )
531557 def test_exchange_code(self):
532558 """Code exchange behaves correctly and handles various error scenarios."""
533559 token = {"type": "bearer"}
612638 @override_config(
613639 {
614640 "oidc_config": {
641 "enabled": True,
642 "client_id": CLIENT_ID,
643 "issuer": ISSUER,
644 "client_auth_method": "client_secret_post",
645 "client_secret_jwt_key": {
646 "key_file": _key_file_path(),
647 "jwt_header": {"alg": "ES256", "kid": "ABC789"},
648 "jwt_payload": {"iss": "DEFGHI"},
649 },
650 }
651 }
652 )
653 def test_exchange_code_jwt_key(self):
654 """Test that code exchange works with a JWK client secret."""
655 from authlib.jose import jwt
656
657 token = {"type": "bearer"}
658 self.http_client.request = simple_async_mock(
659 return_value=FakeResponse(
660 code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
661 )
662 )
663 code = "code"
664
665 # advance the clock a bit before we start, so we aren't working with zero
666 # timestamps.
667 self.reactor.advance(1000)
668 start_time = self.reactor.seconds()
669 ret = self.get_success(self.provider._exchange_code(code))
670
671 self.assertEqual(ret, token)
672
673 # the request should have hit the token endpoint
674 kwargs = self.http_client.request.call_args[1]
675 self.assertEqual(kwargs["method"], "POST")
676 self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
677
678 # the client secret provided to the server should be a JWT which can be
679 # checked with the public key
680 args = parse_qs(kwargs["data"].decode("utf-8"))
681 secret = args["client_secret"][0]
682 with open(_public_key_file_path()) as f:
683 key = f.read()
684 claims = jwt.decode(secret, key)
685 self.assertEqual(claims.header["kid"], "ABC789")
686 self.assertEqual(claims["aud"], ISSUER)
687 self.assertEqual(claims["iss"], "DEFGHI")
688 self.assertEqual(claims["sub"], CLIENT_ID)
689 self.assertEqual(claims["iat"], start_time)
690 self.assertGreater(claims["exp"], start_time)
691
692 # check the rest of the POSTed data
693 self.assertEqual(args["grant_type"], ["authorization_code"])
694 self.assertEqual(args["code"], [code])
695 self.assertEqual(args["client_id"], [CLIENT_ID])
696 self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
697
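For reference, roughly how such a client secret is minted with authlib; the claims mirror the assertions above, and `now` is an illustrative timestamp:

    from authlib.jose import jwt

    now = 1000
    header = {"alg": "ES256", "kid": "ABC789"}
    payload = {"iss": "DEFGHI", "sub": CLIENT_ID, "aud": ISSUER, "iat": now, "exp": now + 60}
    with open(_key_file_path()) as f:
        client_secret = jwt.encode(header, payload, f.read())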
698 @override_config(
699 {
700 "oidc_config": {
701 "enabled": True,
702 "client_id": CLIENT_ID,
703 "issuer": ISSUER,
704 "client_auth_method": "none",
705 }
706 }
707 )
708 def test_exchange_code_no_auth(self):
709 """Test that code exchange works with no client secret."""
710 token = {"type": "bearer"}
711 self.http_client.request = simple_async_mock(
712 return_value=FakeResponse(
713 code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
714 )
715 )
716 code = "code"
717 ret = self.get_success(self.provider._exchange_code(code))
718
719 self.assertEqual(ret, token)
720
721 # the request should have hit the token endpoint
722 kwargs = self.http_client.request.call_args[1]
723 self.assertEqual(kwargs["method"], "POST")
724 self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
725
726 # check the POSTed data
727 args = parse_qs(kwargs["data"].decode("utf-8"))
728 self.assertEqual(args["grant_type"], ["authorization_code"])
729 self.assertEqual(args["code"], [code])
730 self.assertEqual(args["client_id"], [CLIENT_ID])
731 self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
732
733 @override_config(
734 {
735 "oidc_config": {
736 **DEFAULT_CONFIG,
615737 "user_mapping_provider": {
616738 "module": __name__ + ".TestMappingProviderExtra"
617 }
739 },
618740 }
619741 }
620742 )
650772
651773 auth_handler.complete_sso_login.assert_called_once_with(
652774 "@foo:test",
775 "oidc",
653776 request,
654777 client_redirect_url,
655778 {"phone": "1234567"},
656779 new_user=True,
657780 )
658781
782 @override_config({"oidc_config": DEFAULT_CONFIG})
659783 def test_map_userinfo_to_user(self):
660784 """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
661785 auth_handler = self.hs.get_auth_handler()
667791 }
668792 self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
669793 auth_handler.complete_sso_login.assert_called_once_with(
670 "@test_user:test", ANY, ANY, None, new_user=True
794 "@test_user:test", "oidc", ANY, ANY, None, new_user=True
671795 )
672796 auth_handler.complete_sso_login.reset_mock()
673797
678802 }
679803 self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
680804 auth_handler.complete_sso_login.assert_called_once_with(
681 "@test_user_2:test", ANY, ANY, None, new_user=True
805 "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True
682806 )
683807 auth_handler.complete_sso_login.reset_mock()
684808
696820 "Mapping provider does not support de-duplicating Matrix IDs",
697821 )
698822
699 @override_config({"oidc_config": {"allow_existing_users": True}})
823 @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
700824 def test_map_userinfo_to_existing_user(self):
701825 """Existing users can log in with OpenID Connect when allow_existing_users is True."""
702826 store = self.hs.get_datastore()
715839 }
716840 self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
717841 auth_handler.complete_sso_login.assert_called_once_with(
718 user.to_string(), ANY, ANY, None, new_user=False
842 user.to_string(), "oidc", ANY, ANY, None, new_user=False
719843 )
720844 auth_handler.complete_sso_login.reset_mock()
721845
722846 # Subsequent calls should map to the same mxid.
723847 self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
724848 auth_handler.complete_sso_login.assert_called_once_with(
725 user.to_string(), ANY, ANY, None, new_user=False
849 user.to_string(), "oidc", ANY, ANY, None, new_user=False
726850 )
727851 auth_handler.complete_sso_login.reset_mock()
728852
737861 }
738862 self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
739863 auth_handler.complete_sso_login.assert_called_once_with(
740 user.to_string(), ANY, ANY, None, new_user=False
864 user.to_string(), "oidc", ANY, ANY, None, new_user=False
741865 )
742866 auth_handler.complete_sso_login.reset_mock()
743867
773897
774898 self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
775899 auth_handler.complete_sso_login.assert_called_once_with(
776 "@TEST_USER_2:test", ANY, ANY, None, new_user=False
777 )
778
900 "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
901 )
902
903 @override_config({"oidc_config": DEFAULT_CONFIG})
779904 def test_map_userinfo_to_invalid_localpart(self):
780905 """If the mapping provider generates an invalid localpart it should be rejected."""
781906 self.get_success(
786911 @override_config(
787912 {
788913 "oidc_config": {
914 **DEFAULT_CONFIG,
789915 "user_mapping_provider": {
790916 "module": __name__ + ".TestMappingProviderFailures"
791 }
917 },
792918 }
793919 }
794920 )
809935
810936 # test_user is already taken, so test_user1 gets registered instead.
811937 auth_handler.complete_sso_login.assert_called_once_with(
812 "@test_user1:test", ANY, ANY, None, new_user=True
938 "@test_user1:test", "oidc", ANY, ANY, None, new_user=True
813939 )
814940 auth_handler.complete_sso_login.reset_mock()
815941
833959 "mapping_error", "Unable to generate a Matrix ID from the SSO response"
834960 )
835961
962 @override_config({"oidc_config": DEFAULT_CONFIG})
836963 def test_empty_localpart(self):
837964 """Attempts to map onto an empty localpart should be rejected."""
838965 userinfo = {
845972 @override_config(
846973 {
847974 "oidc_config": {
975 **DEFAULT_CONFIG,
848976 "user_mapping_provider": {
849977 "config": {"localpart_template": "{{ user.username }}"}
850 }
978 },
851979 }
852980 }
853981 )
865993 state: str,
866994 nonce: str,
867995 client_redirect_url: str,
868 ui_auth_session_id: Optional[str] = None,
996 ui_auth_session_id: str = "",
869997 ) -> str:
870998 from synapse.handlers.oidc_handler import OidcSessionData
871999
9081036 idp_id="oidc",
9091037 nonce="nonce",
9101038 client_redirect_url=client_redirect_url,
1039 ui_auth_session_id="",
9111040 ),
9121041 )
9131042 request = _build_callback_request("code", state, session)
516516
517517 self.assertTrue(requester.shadow_banned)
518518
519 def test_spam_checker_receives_sso_type(self):
520 """Test rejecting registration based on SSO type"""
521
522 class BanBadIdPUser:
523 def check_registration_for_spam(
524 self, email_threepid, username, request_info, auth_provider_id=None
525 ):
526 # Reject any user coming from CAS and whose username contains profanity
527 if auth_provider_id == "cas" and "flimflob" in username:
528 return RegistrationBehaviour.DENY
529 return RegistrationBehaviour.ALLOW
530
531 # Configure a spam checker that denies a certain user on a specific IdP
532 spam_checker = self.hs.get_spam_checker()
533 spam_checker.spam_checkers = [BanBadIdPUser()]
534
535 f = self.get_failure(
536 self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
537 SynapseError,
538 )
539 exception = f.value
540
541 # We return 429 from the spam checker for denied registrations
542 self.assertIsInstance(exception, SynapseError)
543 self.assertEqual(exception.code, 429)
544
545 # Check the same username can register using SAML
546 self.get_success(
547 self.handler.register_user(localpart="bobflimflob", auth_provider_id="saml")
548 )
549
519550 async def get_or_create_user(
520551 self, requester, localpart, displayname, password_hash=None
521552 ):
130130
131131 # check that the auth handler got called as expected
132132 auth_handler.complete_sso_login.assert_called_once_with(
133 "@test_user:test", request, "redirect_uri", None, new_user=True
133 "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
134134 )
135135
136136 @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
156156
157157 # check that the auth handler got called as expected
158158 auth_handler.complete_sso_login.assert_called_once_with(
159 "@test_user:test", request, "", None, new_user=False
159 "@test_user:test", "saml", request, "", None, new_user=False
160160 )
161161
162162 # Subsequent calls should map to the same mxid.
165165 self.handler._handle_authn_response(request, saml_response, "")
166166 )
167167 auth_handler.complete_sso_login.assert_called_once_with(
168 "@test_user:test", request, "", None, new_user=False
168 "@test_user:test", "saml", request, "", None, new_user=False
169169 )
170170
171171 def test_map_saml_response_to_invalid_localpart(self):
213213
214214 # test_user is already taken, so test_user1 gets registered instead.
215215 auth_handler.complete_sso_login.assert_called_once_with(
216 "@test_user1:test", request, "", None, new_user=True
216 "@test_user1:test", "saml", request, "", None, new_user=True
217217 )
218218 auth_handler.complete_sso_login.reset_mock()
219219
309309
310310 # check that the auth handler got called as expected
311311 auth_handler.complete_sso_login.assert_called_once_with(
312 "@test_user:test", request, "redirect_uri", None, new_user=True
312 "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
313313 )
314314
315315
1515
1616 from mock import Mock
1717
18 from netaddr import IPSet
19
20 from twisted.internet.error import DNSLookupError
1821 from twisted.python.failure import Failure
19 from twisted.web.client import ResponseDone
22 from twisted.test.proto_helpers import AccumulatingProtocol
23 from twisted.web.client import Agent, ResponseDone
2024 from twisted.web.iweb import UNKNOWN_LENGTH
2125
22 from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
23
26 from synapse.api.errors import SynapseError
27 from synapse.http.client import (
28 BlacklistingAgentWrapper,
29 BlacklistingReactorWrapper,
30 BodyExceededMaxSize,
31 read_body_with_max_size,
32 )
33
34 from tests.server import FakeTransport, get_clock
2435 from tests.unittest import TestCase
2536
2637
118129
119130 # The data is never consumed.
120131 self.assertEqual(result.getvalue(), b"")
132
133
134 class BlacklistingAgentTest(TestCase):
135 def setUp(self):
136 self.reactor, self.clock = get_clock()
137
138 self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
139 self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
140 self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
141
142 # Configure the reactor's DNS resolver.
143 for (domain, ip) in (
144 (self.safe_domain, self.safe_ip),
145 (self.unsafe_domain, self.unsafe_ip),
146 (self.allowed_domain, self.allowed_ip),
147 ):
148 self.reactor.lookups[domain.decode()] = ip.decode()
149 self.reactor.lookups[ip.decode()] = ip.decode()
150
151 self.ip_whitelist = IPSet([self.allowed_ip.decode()])
152 self.ip_blacklist = IPSet(["5.0.0.0/8"])
153
154 def test_reactor(self):
155 """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
156 agent = Agent(
157 BlacklistingReactorWrapper(
158 self.reactor,
159 ip_whitelist=self.ip_whitelist,
160 ip_blacklist=self.ip_blacklist,
161 ),
162 )
163
164 # The unsafe domains and IPs should be rejected.
165 for domain in (self.unsafe_domain, self.unsafe_ip):
166 self.failureResultOf(
167 agent.request(b"GET", b"http://" + domain), DNSLookupError
168 )
169
170 # The safe domains and IPs should be accepted.
171 for domain in (
172 self.safe_domain,
173 self.allowed_domain,
174 self.safe_ip,
175 self.allowed_ip,
176 ):
177 d = agent.request(b"GET", b"http://" + domain)
178
179 # Grab the latest TCP connection.
180 (
181 host,
182 port,
183 client_factory,
184 _timeout,
185 _bindAddress,
186 ) = self.reactor.tcpClients[-1]
187
188 # Make the connection and pump data through it.
189 client = client_factory.buildProtocol(None)
190 server = AccumulatingProtocol()
191 server.makeConnection(FakeTransport(client, self.reactor))
192 client.makeConnection(FakeTransport(server, self.reactor))
193 client.dataReceived(
194 b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
195 )
196
197 response = self.successResultOf(d)
198 self.assertEqual(response.code, 200)
199
200 def test_agent(self):
201 """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
202 agent = BlacklistingAgentWrapper(
203 Agent(self.reactor),
204 ip_whitelist=self.ip_whitelist,
205 ip_blacklist=self.ip_blacklist,
206 )
207
208 # The unsafe IPs should be rejected.
209 self.failureResultOf(
210 agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError
211 )
212
213 # The safe and unsafe domains and safe IPs should be accepted.
214 for domain in (
215 self.safe_domain,
216 self.unsafe_domain,
217 self.allowed_domain,
218 self.safe_ip,
219 self.allowed_ip,
220 ):
221 d = agent.request(b"GET", b"http://" + domain)
222
223 # Grab the latest TCP connection.
224 (
225 host,
226 port,
227 client_factory,
228 _timeout,
229 _bindAddress,
230 ) = self.reactor.tcpClients[-1]
231
232 # Make the connection and pump data through it.
233 client = client_factory.buildProtocol(None)
234 server = AccumulatingProtocol()
235 server.makeConnection(FakeTransport(client, self.reactor))
236 client.makeConnection(FakeTransport(server, self.reactor))
237 client.dataReceived(
238 b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
239 )
240
241 response = self.successResultOf(d)
242 self.assertEqual(response.code, 200)
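The allow/deny decision these wrappers make reduces to a whitelist-overrides-blacklist check on `netaddr` sets; roughly:

    from netaddr import IPAddress, IPSet

    ip_blacklist = IPSet(["5.0.0.0/8"])
    ip_whitelist = IPSet(["5.1.1.1"])

    ip = IPAddress("5.1.1.1")
    blocked = ip in ip_blacklist and ip not in ip_whitelist
    assert not blocked  # whitelisted, despite matching the blacklist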
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414 import logging
15 from typing import Any, Callable, Dict, List, Optional, Tuple
16
17 import attr
15 from typing import Any, Callable, Dict, List, Optional, Tuple, Type
1816
1917 from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
2018 from twisted.internet.protocol import Protocol
2119 from twisted.internet.task import LoopingCall
2220 from twisted.web.http import HTTPChannel
2321 from twisted.web.resource import Resource
22 from twisted.web.server import Request, Site
2423
2524 from synapse.app.generic_worker import (
2625 GenericWorkerReplicationHandler,
3130 from synapse.replication.http import ReplicationRestResource
3231 from synapse.replication.tcp.handler import ReplicationCommandHandler
3332 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
34 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
33 from synapse.replication.tcp.resource import (
34 ReplicationStreamProtocolFactory,
35 ServerReplicationStreamProtocol,
36 )
3537 from synapse.server import HomeServer
3638 from synapse.util import Clock
3739
5860 # build a replication server
5961 server_factory = ReplicationStreamProtocolFactory(hs)
6062 self.streamer = hs.get_replication_streamer()
61 self.server = server_factory.buildProtocol(None)
63 self.server = server_factory.buildProtocol(
64 None
65 ) # type: ServerReplicationStreamProtocol
6266
6367 # Make a new HomeServer object for the worker
6468 self.reactor.lookups["testserv"] = "1.2.3.4"
151155 # Set up client side protocol
152156 client_protocol = client_factory.buildProtocol(None)
153157
154 request_factory = OneShotRequestFactory()
155
156158 # Set up the server side protocol
157 channel = _PushHTTPChannel(self.reactor)
158 channel.requestFactory = request_factory
159 channel.site = self.site
159 channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
160160
161161 # Connect client to server and vice versa.
162162 client_to_server_transport = FakeTransport(
178178 server_to_client_transport.loseConnection()
179179 client_to_server_transport.loseConnection()
180180
181 return request_factory.request
181 return channel.request
182182
183183 def assert_request_is_get_repl_stream_updates(
184184 self, request: SynapseRequest, stream_name: str
187187 fetching updates for given stream.
188188 """
189189
190 path = request.path # type: bytes # type: ignore
190191 self.assertRegex(
191 request.path,
192 path,
192193 br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
193194 % (stream_name.encode("ascii"),),
194195 )
231232 if self.hs.config.redis.redis_enabled:
232233 # Handle attempts to connect to fake redis server.
233234 self.reactor.add_tcp_client_callback(
234 "localhost",
235 b"localhost",
235236 6379,
236237 self.connect_any_redis_attempts,
237238 )
386387 # Set up client side protocol
387388 client_protocol = client_factory.buildProtocol(None)
388389
389 request_factory = OneShotRequestFactory()
390
391390 # Set up the server side protocol
392 channel = _PushHTTPChannel(self.reactor)
393 channel.requestFactory = request_factory
394 channel.site = self._hs_to_site[hs]
391 channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
395392
396393 # Connect client to server and vice versa.
397394 client_to_server_transport = FakeTransport(
417414 clients = self.reactor.tcpClients
418415 while clients:
419416 (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
420 self.assertEqual(host, "localhost")
417 self.assertEqual(host, b"localhost")
421418 self.assertEqual(port, 6379)
422419
423420 client_protocol = client_factory.buildProtocol(None)
447444 await super().on_rdata(stream_name, instance_name, token, rows)
448445 for r in rows:
449446 self.received_rdata_rows.append((stream_name, token, r))
450
451
452 @attr.s()
453 class OneShotRequestFactory:
454 """A simple request factory that generates a single `SynapseRequest` and
455 stores it for future use. Can only be used once.
456 """
457
458 request = attr.ib(default=None)
459
460 def __call__(self, *args, **kwargs):
461 assert self.request is None
462
463 self.request = SynapseRequest(*args, **kwargs)
464 return self.request
465447
466448
467449 class _PushHTTPChannel(HTTPChannel):
474456 makes it very hard to test.
475457 """
476458
477 def __init__(self, reactor: IReactorTime):
459 def __init__(
460 self, reactor: IReactorTime, request_factory: Type[Request], site: Site
461 ):
478462 super().__init__()
479463 self.reactor = reactor
464 self.requestFactory = request_factory
465 self.site = site
480466
481467 self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
482468
501487 # `handle_http_replication_attempt`.
502488 request.responseHeaders.setRawHeaders(b"connection", [b"close"])
503489 return False
490
491 def requestDone(self, request):
492 # Store the request for inspection.
493 self.request = request
494 super().requestDone(request)
504495
505496
506497 class _PullToPushProducer:
588579
589580 class FakeRedisPubSubProtocol(Protocol):
590581 """A connection from a client talking to the fake Redis server."""
582
583 transport = None # type: Optional[FakeTransport]
591584
592585 def __init__(self, server: FakeRedisPubSubServer):
593586 self._server = server
633626
634627 def send(self, msg):
635628 """Send a message back to the client."""
629 assert self.transport is not None
630
636631 raw = self.encode(msg).encode("utf-8")
637632
638633 self.transport.write(raw)
1616
1717 from synapse.app.generic_worker import GenericWorkerServer
1818 from synapse.replication.tcp.commands import FederationAckCommand
19 from synapse.replication.tcp.protocol import AbstractConnection
19 from synapse.replication.tcp.protocol import IReplicationConnection
2020 from synapse.replication.tcp.streams.federation import FederationStream
2121
2222 from tests.unittest import HomeserverTestCase
5050 """
5151 rch = self.hs.get_tcp_replication()
5252
53 # wire up the ReplicationCommandHandler to a mock connection
54 mock_connection = mock.Mock(spec=AbstractConnection)
53 # wire up the ReplicationCommandHandler to a mock connection, which needs
54 # to implement IReplicationConnection. (Note that Mock doesn't understand
55 # interfaces, but casting an interface to a list gives the attributes.)
56 mock_connection = mock.Mock(spec=list(IReplicationConnection))
5557 rch.new_connection(mock_connection)
5658
5759 # tell it it received an RDATA row
436436 channel = self.make_request("GET", "/_matrix/client/r0/login")
437437 self.assertEqual(channel.code, 200, channel.result)
438438
439 expected_flows = [
440 {"type": "m.login.cas"},
441 {"type": "m.login.sso"},
442 {"type": "m.login.token"},
443 {"type": "m.login.password"},
444 ] + ADDITIONAL_LOGIN_FLOWS
445
446 self.assertCountEqual(channel.json_body["flows"], expected_flows)
439 expected_flow_types = [
440 "m.login.cas",
441 "m.login.sso",
442 "m.login.token",
443 "m.login.password",
444 ] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS]
445
446 self.assertCountEqual(
447 [f["type"] for f in channel.json_body["flows"]], expected_flow_types
448 )
447449
448450 @override_config({"experimental_features": {"msc2858_enabled": True}})
449451 def test_get_msc2858_login_flows(self):
635637 )
636638 self.assertEqual(channel.code, 400, channel.result)
637639
638 def test_client_idp_redirect_msc2858_disabled(self):
639 """If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
640 channel = self._make_sso_redirect_request(True, "oidc")
641 self.assertEqual(channel.code, 400, channel.result)
642 self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
643
644 @override_config({"experimental_features": {"msc2858_enabled": True}})
645640 def test_client_idp_redirect_to_unknown(self):
646641 """If the client tries to pick an unknown IdP, return a 404"""
647 channel = self._make_sso_redirect_request(True, "xxx")
642 channel = self._make_sso_redirect_request(False, "xxx")
648643 self.assertEqual(channel.code, 404, channel.result)
649644 self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
650645
651 @override_config({"experimental_features": {"msc2858_enabled": True}})
652646 def test_client_idp_redirect_to_oidc(self):
653647 """If the client pick a known IdP, redirect to it"""
648 channel = self._make_sso_redirect_request(False, "oidc")
649 self.assertEqual(channel.code, 302, channel.result)
650 oidc_uri = channel.headers.getRawHeaders("Location")[0]
651 oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
652
653 # it should redirect us to the auth page of the OIDC server
654 self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
655
656 @override_config({"experimental_features": {"msc2858_enabled": True}})
657 def test_client_msc2858_redirect_to_oidc(self):
658 """Test the unstable API"""
654659 channel = self._make_sso_redirect_request(True, "oidc")
655660 self.assertEqual(channel.code, 302, channel.result)
656661 oidc_uri = channel.headers.getRawHeaders("Location")[0]
658663
659664 # it should redirect us to the auth page of the OIDC server
660665 self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
666
667 def test_client_idp_redirect_msc2858_disabled(self):
668 """If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400"""
669 channel = self._make_sso_redirect_request(True, "oidc")
670 self.assertEqual(channel.code, 400, channel.result)
671 self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
661672
662673 def _make_sso_redirect_request(
663674 self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None
104104 self.assertEqual(test_body, body)
105105
106106
107 @attr.s
107 @attr.s(slots=True, frozen=True)
108108 class _TestImage:
109109 """An image for testing thumbnailing with the expected results
110110
116116 test should just check for success.
117117 expected_scaled: The expected bytes from scaled thumbnailing, or None if
118118 test should just check for a valid image returned.
119 expected_found: True if the file should exist on the server, or False if
120 a 404 is expected.
119121 """
120122
121123 data = attr.ib(type=bytes)
122124 content_type = attr.ib(type=bytes)
123125 extension = attr.ib(type=bytes)
124 expected_cropped = attr.ib(type=Optional[bytes])
125 expected_scaled = attr.ib(type=Optional[bytes])
126 expected_cropped = attr.ib(type=Optional[bytes], default=None)
127 expected_scaled = attr.ib(type=Optional[bytes], default=None)
126128 expected_found = attr.ib(default=True, type=bool)
127129
128130
152154 ),
153155 ),
154156 ),
157 # small png with transparency.
158 (
159 _TestImage(
160 unhexlify(
161 b"89504e470d0a1a0a0000000d49484452000000010000000101000"
162 b"00000376ef9240000000274524e5300010194fdae0000000a4944"
163 b"4154789c636800000082008177cd72b60000000049454e44ae426"
164 b"082"
165 ),
166 b"image/png",
167 b".png",
168 # Note that we don't check the output since it varies across
169 # different versions of Pillow.
170 ),
171 ),
155172 # small lossless webp
156173 (
157174 _TestImage(
161178 ),
162179 b"image/webp",
163180 b".webp",
164 None,
165 None,
166181 ),
167182 ),
168183 # an empty file
171186 b"",
172187 b"image/gif",
173188 b".gif",
174 None,
175 None,
176 False,
189 expected_found=False,
177190 ),
178191 ),
179192 ],
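The move to `@attr.s(slots=True, frozen=True)` together with `default=None` on the expected-thumbnail fields is what lets the cases above drop the trailing `None` arguments and pass `expected_found=False` by keyword; a minimal self-contained sketch of that attrs pattern (class and field names here are illustrative):

```python
from typing import Optional

import attr

@attr.s(slots=True, frozen=True)
class _Image:
    data = attr.ib(type=bytes)
    content_type = attr.ib(type=bytes)
    # Fields with defaults can simply be omitted at the call site...
    expected_cropped = attr.ib(type=Optional[bytes], default=None)
    # ...or passed by keyword, as expected_found=False is above.
    expected_found = attr.ib(type=bool, default=True)

img = _Image(b"\x89PNG", b"image/png")
assert img.expected_cropped is None and img.expected_found
# frozen=True makes instances immutable: attribute assignment raises
# attr.exceptions.FrozenInstanceError; slots=True rejects stray attributes.
```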
1515 IReactorPluggableNameResolver,
1616 IReactorTCP,
1717 IResolverSimple,
18 ITransport,
1819 )
1920 from twisted.python.failure import Failure
2021 from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
187188
188189 def make_request(
189190 reactor,
190 site: Site,
191 site: Union[Site, FakeSite],
191192 method,
192193 path,
193194 content=b"",
466467 return clock, hs_clock
467468
468469
470 @implementer(ITransport)
469471 @attr.s(cmp=False)
470472 class FakeTransport:
471473 """
117117 r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
118118 self.assertTrue(r == [room2] or r == [room3])
119119
120 @parameterized.expand([(True,), (False,)])
121 def test_auth_difference(self, use_chain_cover_index: bool):
120 def _setup_auth_chain(self, use_chain_cover_index: bool) -> str:
122121 room_id = "@ROOM:local"
123122
124123 # The silly auth graph we use to test the auth difference algorithm,
164163 "j": 1,
165164 }
166165
167 # Mark the room as not having a cover index
166 # Mark the room as maybe having a cover index.
168167
169168 def store_room(txn):
170169 self.store.db_pool.simple_insert_txn(
221220 )
222221 )
223222
223 return room_id
224
225 @parameterized.expand([(True,), (False,)])
226 def test_auth_chain_ids(self, use_chain_cover_index: bool):
227 room_id = self._setup_auth_chain(use_chain_cover_index)
228
229 # a and b have the same auth chain.
230 auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["a"]))
231 self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
232 auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["b"]))
233 self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
234 auth_chain_ids = self.get_success(
235 self.store.get_auth_chain_ids(room_id, ["a", "b"])
236 )
237 self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
238
239 auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["c"]))
240 self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
241
242 # d and e have the same auth chain.
243 auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["d"]))
244 self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
245 auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["e"]))
246 self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
247
248 auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["f"]))
249 self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
250
251 auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["g"]))
252 self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"])
253
254 auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"]))
255 self.assertEqual(auth_chain_ids, ["k"])
256
257 auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"]))
258 self.assertEqual(auth_chain_ids, ["j"])
259
260 # j and k have no parents.
261 auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"]))
262 self.assertEqual(auth_chain_ids, [])
263 auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"]))
264 self.assertEqual(auth_chain_ids, [])
265
266 # More complex input sequences.
267 auth_chain_ids = self.get_success(
268 self.store.get_auth_chain_ids(room_id, ["b", "c", "d"])
269 )
270 self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
271
272 auth_chain_ids = self.get_success(
273 self.store.get_auth_chain_ids(room_id, ["h", "i"])
274 )
275 self.assertCountEqual(auth_chain_ids, ["k", "j"])
276
277 # e is returned even though include_given is false, because it is in
278 # the auth chain of b.
279 auth_chain_ids = self.get_success(
280 self.store.get_auth_chain_ids(room_id, ["b", "e"])
281 )
282 self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
283
284 # Test include_given.
285 auth_chain_ids = self.get_success(
286 self.store.get_auth_chain_ids(room_id, ["i"], include_given=True)
287 )
288 self.assertCountEqual(auth_chain_ids, ["i", "j"])
289
290 @parameterized.expand([(True,), (False,)])
291 def test_auth_difference(self, use_chain_cover_index: bool):
292 room_id = self._setup_auth_chain(use_chain_cover_index)
293
224294 # Now actually test that various combinations give the right result:
225295
226296 difference = self.get_success(
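Extracting `_setup_auth_chain` lets both `test_auth_chain_ids` and `test_auth_difference` run against the chain-cover and non-chain-cover storage paths via `parameterized.expand`; a minimal sketch of that shared-fixture pattern (the fixture logic here is a stand-in, not the real auth graph):

```python
import unittest

from parameterized import parameterized

class AuthChainStyleTests(unittest.TestCase):
    def _setup(self, use_chain_cover_index: bool) -> str:
        # Shared fixture; in the real tests this builds the auth graph and
        # conditionally marks the room as having a chain cover index.
        return "index" if use_chain_cover_index else "no-index"

    # Each decorated test body runs once per parameter tuple.
    @parameterized.expand([(True,), (False,)])
    def test_setup(self, use_chain_cover_index: bool):
        room = self._setup(use_chain_cover_index)
        self.assertIn(room, ("index", "no-index"))

if __name__ == "__main__":
    unittest.main()
```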
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 from twisted.internet import defer
16
17 from synapse.api.errors import NotFoundError
15 from synapse.api.errors import NotFoundError, SynapseError
1816 from synapse.rest.client.v1 import room
1917
2018 from tests.unittest import HomeserverTestCase
3230 def prepare(self, reactor, clock, hs):
3331 self.room_id = self.helper.create_room_as(self.user_id)
3432
35 def test_purge(self):
33 self.store = hs.get_datastore()
34 self.storage = self.hs.get_storage()
35
36 def test_purge_history(self):
3637 """
37 Purging a room will delete everything before the topological point.
38 Purging a room's history will delete everything before the topological point.
3839 """
3940 # Send four messages to the room
4041 first = self.helper.send(self.room_id, body="test1")
4243 third = self.helper.send(self.room_id, body="test3")
4344 last = self.helper.send(self.room_id, body="test4")
4445
45 store = self.hs.get_datastore()
46 storage = self.hs.get_storage()
47
4846 # Get the topological token
4947 token = self.get_success(
50 store.get_topological_token_for_event(last["event_id"])
48 self.store.get_topological_token_for_event(last["event_id"])
5149 )
5250 token_str = self.get_success(token.to_string(self.hs.get_datastore()))
5351
5452 # Purge everything before this topological token
5553 self.get_success(
56 storage.purge_events.purge_history(self.room_id, token_str, True)
54 self.storage.purge_events.purge_history(self.room_id, token_str, True)
5755 )
5856
5957 # 1-3 should fail and last will succeed, meaning that 1-3 are deleted
6058 # and last is not.
61 self.get_failure(store.get_event(first["event_id"]), NotFoundError)
62 self.get_failure(store.get_event(second["event_id"]), NotFoundError)
63 self.get_failure(store.get_event(third["event_id"]), NotFoundError)
64 self.get_success(store.get_event(last["event_id"]))
59 self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
60 self.get_failure(self.store.get_event(second["event_id"]), NotFoundError)
61 self.get_failure(self.store.get_event(third["event_id"]), NotFoundError)
62 self.get_success(self.store.get_event(last["event_id"]))
6563
66 def test_purge_wont_delete_extrems(self):
64 def test_purge_history_wont_delete_extrems(self):
6765 """
68 Purging a room will delete everything before the topological point.
66 Purging a room's history will delete everything before the topological point.
6967 """
7068 # Send four messages to the room
7169 first = self.helper.send(self.room_id, body="test1")
7371 third = self.helper.send(self.room_id, body="test3")
7472 last = self.helper.send(self.room_id, body="test4")
7573
76 storage = self.hs.get_datastore()
77
7874 # Set the topological token higher than it should be
7975 token = self.get_success(
80 storage.get_topological_token_for_event(last["event_id"])
76 self.store.get_topological_token_for_event(last["event_id"])
8177 )
8278 event = "t{}-{}".format(token.topological + 1, token.stream + 1)
8379
8480 # Purge everything before this topological token
85 purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
86 self.pump()
87 f = self.failureResultOf(purge)
81 f = self.get_failure(
82 self.storage.purge_events.purge_history(self.room_id, event, True),
83 SynapseError,
84 )
8885 self.assertIn("greater than forward", f.value.args[0])
8986
9087 # Try and get the events
91 self.get_success(storage.get_event(first["event_id"]))
92 self.get_success(storage.get_event(second["event_id"]))
93 self.get_success(storage.get_event(third["event_id"]))
94 self.get_success(storage.get_event(last["event_id"]))
88 self.get_success(self.store.get_event(first["event_id"]))
89 self.get_success(self.store.get_event(second["event_id"]))
90 self.get_success(self.store.get_event(third["event_id"]))
91 self.get_success(self.store.get_event(last["event_id"]))
92
93 def test_purge_room(self):
94 """
95 Purging a room will delete everything about it.
96 """
97 # Send four messages to the room
98 first = self.helper.send(self.room_id, body="test1")
99
100 # Get the current room state.
101 state_handler = self.hs.get_state_handler()
102 create_event = self.get_success(
103 state_handler.get_current_state(self.room_id, "m.room.create", "")
104 )
105 self.assertIsNotNone(create_event)
106
107 # Purge the whole room.
108 self.get_success(self.storage.purge_events.purge_room(self.room_id))
109
110 # The events aren't found.
111 self.store._invalidate_get_event_cache(create_event.event_id)
112 self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
113 self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
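The switch from the manual `defer.ensureDeferred(...)` / `self.pump()` / `self.failureResultOf(...)` sequence to `self.get_failure(...)` relies on a test helper; roughly, such a helper can be modelled like this (a sketch under the assumption of a Twisted trial `TestCase`, not Synapse's exact implementation):

```python
from twisted.internet import defer
from twisted.trial import unittest

class GetFailureSketch(unittest.TestCase):
    def get_failure(self, awaitable, exc_type):
        """Run a coroutine/Deferred to completion and return its Failure."""
        d = defer.ensureDeferred(awaitable)
        # In the real helper the memory reactor is also pumped here so any
        # pending callLaters fire before the result is inspected.
        return self.failureResultOf(d, exc_type)

    def test_example(self):
        async def boom():
            raise ValueError("greater than forward extremities")

        f = self.get_failure(boom(), ValueError)
        self.assertIn("greater than forward", f.value.args[0])
```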
2727 def emit(self, record):
2828 log_entry = self.format(record)
2929 log_level = record.levelname.lower().replace("warning", "warn")
30 self.tx_log.emit(
30 self.tx_log.emit( # type: ignore
3131 twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
3232 )
3333
0 # Copyright 2021 The Matrix.org Foundation C.I.C.
1 #
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at
5 #
6 # http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 from synapse.util.caches.response_cache import ResponseCache
15
16 from tests.server import get_clock
17 from tests.unittest import TestCase
18
19
20 class ResponseCacheTestCase(TestCase):
21 """
22 A TestCase class for ResponseCache.
23
24 The test-case method names follow a naming scheme; notes on its parts:
25 wait: denotes tests in which there is some waiting before the wrapped result becomes available
26 (these generally use .delayed_return instead of .instant_return in the wrapped call).
27 expire: denotes tests that check expiry after the entry is known to exist
28 (these use a cache with a short timeout_ms, shorter than the interval the clock is advanced by).
29 """
30
31 def setUp(self):
32 self.reactor, self.clock = get_clock()
33
34 def with_cache(self, name: str, ms: int = 0) -> ResponseCache:
35 return ResponseCache(self.clock, name, timeout_ms=ms)
36
37 @staticmethod
38 async def instant_return(o: str) -> str:
39 return o
40
41 async def delayed_return(self, o: str) -> str:
42 await self.clock.sleep(1)
43 return o
44
45 def test_cache_hit(self):
46 cache = self.with_cache("keeping_cache", ms=9001)
47
48 expected_result = "howdy"
49
50 wrap_d = cache.wrap(0, self.instant_return, expected_result)
51
52 self.assertEqual(
53 expected_result,
54 self.successResultOf(wrap_d),
55 "initial wrap result should be the same",
56 )
57 self.assertEqual(
58 expected_result,
59 self.successResultOf(cache.get(0)),
60 "cache should have the result",
61 )
62
63 def test_cache_miss(self):
64 cache = self.with_cache("trashing_cache", ms=0)
65
66 expected_result = "howdy"
67
68 wrap_d = cache.wrap(0, self.instant_return, expected_result)
69
70 self.assertEqual(
71 expected_result,
72 self.successResultOf(wrap_d),
73 "initial wrap result should be the same",
74 )
75 self.assertIsNone(cache.get(0), "cache should not have the result now")
76
77 def test_cache_expire(self):
78 cache = self.with_cache("short_cache", ms=1000)
79
80 expected_result = "howdy"
81
82 wrap_d = cache.wrap(0, self.instant_return, expected_result)
83
84 self.assertEqual(expected_result, self.successResultOf(wrap_d))
85 self.assertEqual(
86 expected_result,
87 self.successResultOf(cache.get(0)),
88 "cache should still have the result",
89 )
90
91 # advance the clock past the timeout so the cache-eviction timer fires
92 self.reactor.pump((2,))
93
94 self.assertIsNone(cache.get(0), "cache should not have the result now")
95
96 def test_cache_wait_hit(self):
97 cache = self.with_cache("neutral_cache")
98
99 expected_result = "howdy"
100
101 wrap_d = cache.wrap(0, self.delayed_return, expected_result)
102 self.assertNoResult(wrap_d)
103
104 # function wakes up, returns result
105 self.reactor.pump((2,))
106
107 self.assertEqual(expected_result, self.successResultOf(wrap_d))
108
109 def test_cache_wait_expire(self):
110 cache = self.with_cache("medium_cache", ms=3000)
111
112 expected_result = "howdy"
113
114 wrap_d = cache.wrap(0, self.delayed_return, expected_result)
115 self.assertNoResult(wrap_d)
116
117 # advance 1 second so the cache-eviction callLater fires, then another second to bring the time to 2
118 self.reactor.pump((1, 1))
119
120 self.assertEqual(expected_result, self.successResultOf(wrap_d))
121 self.assertEqual(
122 expected_result,
123 self.successResultOf(cache.get(0)),
124 "cache should still have the result",
125 )
126
127 # (1 + 1 + 2) seconds > the 3.0s timeout, so the eviction timer has fired
128 self.reactor.pump((2,))
129
130 self.assertIsNone(cache.get(0), "cache should not have the result now")
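The semantics these cases pin down: `wrap` deduplicates concurrent calls under one key and, with a non-zero `timeout_ms`, keeps the finished result until the timeout expires; with `timeout_ms=0` the entry is dropped as soon as the result is available. A toy asyncio model of that behaviour (illustrative only, not Synapse's `ResponseCache`):

```python
import asyncio

class TinyResponseCache:
    """Toy model of the wrap()/get() behaviour exercised above."""

    def __init__(self, timeout_s: float):
        self._timeout_s = timeout_s
        self._entries = {}

    async def wrap(self, key, callback, *args):
        entry = self._entries.get(key)
        if entry is None:
            # First caller computes; concurrent callers share the same task.
            entry = asyncio.ensure_future(callback(*args))
            self._entries[key] = entry
            if self._timeout_s > 0:
                # Keep the finished result around until the timeout...
                asyncio.get_running_loop().call_later(
                    self._timeout_s, self._entries.pop, key, None
                )
            else:
                # ...or, with no timeout, only deduplicate in-flight calls.
                entry.add_done_callback(lambda _: self._entries.pop(key, None))
        return await entry

    def get(self, key):
        entry = self._entries.get(key)
        return entry.result() if entry is not None and entry.done() else None

async def main():
    cache = TinyResponseCache(timeout_s=0)
    result = await cache.wrap(0, asyncio.sleep, 0, "howdy")
    assert result == "howdy"
    assert cache.get(0) is None  # mirrors test_cache_miss above

asyncio.run(main())
```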
188188 [testenv:mypy]
189189 deps =
190190 {[base]deps}
191 # Type hints are broken with Twisted > 20.3.0, see https://github.com/matrix-org/synapse/issues/9513
192 twisted==20.3.0
193191 extras = all,mypy
194192 commands = mypy