Fix: Getting the value of a StringToString pflag (#874)
* add parsing for stringToString flags
* add logic to return flags default if not val set, add a test
* extract parsing into single func
* add a few more cases
* return nil if unable to parse instead of panicing
* return map[string]interface in order to work with cast.ToStringMap
* mostly copy pflags implementation of the conversion to a stringtostring
Trevor Foster authored 3 years ago
GitHub committed 3 years ago
1082 | 1082 | s = strings.TrimSuffix(s, "]") |
1083 | 1083 | res, _ := readAsCSV(s) |
1084 | 1084 | return cast.ToIntSlice(res) |
1085 | case "stringToString": | |
1086 | return stringToStringConv(flag.ValueString()) | |
1085 | 1087 | default: |
1086 | 1088 | return flag.ValueString() |
1087 | 1089 | } |
1157 | 1159 | s = strings.TrimSuffix(s, "]") |
1158 | 1160 | res, _ := readAsCSV(s) |
1159 | 1161 | return cast.ToIntSlice(res) |
1162 | case "stringToString": | |
1163 | return stringToStringConv(flag.ValueString()) | |
1160 | 1164 | default: |
1161 | 1165 | return flag.ValueString() |
1162 | 1166 | } |
1174 | 1178 | stringReader := strings.NewReader(val) |
1175 | 1179 | csvReader := csv.NewReader(stringReader) |
1176 | 1180 | return csvReader.Read() |
1181 | } | |
1182 | ||
1183 | // mostly copied from pflag's implementation of this operation here https://github.com/spf13/pflag/blob/master/string_to_string.go#L79 | |
1184 | // alterations are: errors are swallowed, map[string]interface{} is returned in order to enable cast.ToStringMap | |
1185 | func stringToStringConv(val string) interface{} { | |
1186 | val = strings.Trim(val, "[]") | |
1187 | // An empty string would cause an empty map | |
1188 | if len(val) == 0 { | |
1189 | return map[string]interface{}{} | |
1190 | } | |
1191 | r := csv.NewReader(strings.NewReader(val)) | |
1192 | ss, err := r.Read() | |
1193 | if err != nil { | |
1194 | return nil | |
1195 | } | |
1196 | out := make(map[string]interface{}, len(ss)) | |
1197 | for _, pair := range ss { | |
1198 | kv := strings.SplitN(pair, "=", 2) | |
1199 | if len(kv) != 2 { | |
1200 | return nil | |
1201 | } | |
1202 | out[kv[0]] = kv[1] | |
1203 | } | |
1204 | return out | |
1177 | 1205 | } |
1178 | 1206 | |
1179 | 1207 | // IsSet checks to see if the key has been set in any of the data locations. |
967 | 967 | flag.Changed = true // hack for pflag usage |
968 | 968 | |
969 | 969 | assert.Equal(t, "testing_mutate", Get("testvalue")) |
970 | } | |
971 | ||
972 | func TestBindPFlagStringToString(t *testing.T) { | |
973 | tests := []struct { | |
974 | Expected map[string]string | |
975 | Value string | |
976 | }{ | |
977 | {map[string]string{}, ""}, | |
978 | {map[string]string{"yo": "hi"}, "yo=hi"}, | |
979 | {map[string]string{"yo": "hi", "oh": "hi=there"}, "yo=hi,oh=hi=there"}, | |
980 | {map[string]string{"yo": ""}, "yo="}, | |
981 | {map[string]string{"yo": "", "oh": "hi=there"}, "yo=,oh=hi=there"}, | |
982 | } | |
983 | ||
984 | v := New() // create independent Viper object | |
985 | defaultVal := map[string]string{} | |
986 | v.SetDefault("stringtostring", defaultVal) | |
987 | ||
988 | for _, testValue := range tests { | |
989 | flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError) | |
990 | flagSet.StringToString("stringtostring", testValue.Expected, "test") | |
991 | ||
992 | for _, changed := range []bool{true, false} { | |
993 | flagSet.VisitAll(func(f *pflag.Flag) { | |
994 | f.Value.Set(testValue.Value) | |
995 | f.Changed = changed | |
996 | }) | |
997 | ||
998 | err := v.BindPFlags(flagSet) | |
999 | if err != nil { | |
1000 | t.Fatalf("error binding flag set, %v", err) | |
1001 | } | |
1002 | ||
1003 | type TestMap struct { | |
1004 | StringToString map[string]string | |
1005 | } | |
1006 | val := &TestMap{} | |
1007 | if err := v.Unmarshal(val); err != nil { | |
1008 | t.Fatalf("%+#v cannot unmarshal: %s", testValue.Value, err) | |
1009 | } | |
1010 | if changed { | |
1011 | assert.Equal(t, testValue.Expected, val.StringToString) | |
1012 | } else { | |
1013 | assert.Equal(t, defaultVal, val.StringToString) | |
1014 | } | |
1015 | } | |
1016 | } | |
970 | 1017 | } |
971 | 1018 | |
972 | 1019 | func TestBoundCaseSensitivity(t *testing.T) { |